SagemakerGenerator
This component enables text generation using LLMs deployed on Amazon SageMaker.
Name | SagemakerGenerator |
Source | https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_sagemaker |
Most common position in a pipeline | After a PromptBuilder |
Mandatory input variables | “prompt”: A string containing the prompt for the LLM |
Output variables | “replies”: A list of strings with all the replies generated by the LLM; “meta”: A list of dictionaries with the metadata associated with each reply, such as token count and finish reason |
SagemakerGenerator allows you to use models deployed on Amazon SageMaker.
Parameters Overview
SagemakerGenerator needs AWS credentials to work. Set the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables.
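To fail early with a clear message, you can check for both variables before creating the component. This is a minimal sketch using only the standard library; missing_aws_credentials is a hypothetical helper, not part of the integration, and it only checks the two variables named above.

```python
import os

# The two environment variables named in this section.
REQUIRED_AWS_VARS = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")


def missing_aws_credentials(env=None):
    """Return the names of required AWS variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_AWS_VARS if not env.get(name)]


def require_aws_credentials(env=None):
    """Raise a descriptive error if any required variable is missing."""
    missing = missing_aws_credentials(env)
    if missing:
        raise RuntimeError("Missing AWS credentials: " + ", ".join(missing))
```

Calling require_aws_credentials() right before constructing SagemakerGenerator surfaces a configuration problem before any request is sent.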
You also need to specify your SageMaker endpoint at initialization time for the component to work. Pass the endpoint name to the model parameter like this:
generator = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16")
Additionally, you can pass any text generation parameters valid for your specific model directly to SagemakerGenerator using the generation_kwargs parameter, both at initialization and to the run() method.
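As a rough illustration, the sketch below passes generation_kwargs in both places. The parameter names max_new_tokens and temperature are assumptions: which names are valid depends on the model behind your endpoint. The generate function is a hypothetical wrapper, and the import happens inside it so the snippet loads even where the integration is not installed.

```python
# Assumption: the deployed model accepts these parameter names.
INIT_KWARGS = {"max_new_tokens": 100}
RUN_KWARGS = {"max_new_tokens": 30, "temperature": 0.2}


def generate(endpoint_name: str, prompt: str) -> dict:
    """Create a generator with INIT_KWARGS, then pass RUN_KWARGS for one call."""
    # Imported lazily so this module loads without the integration installed.
    from haystack_integrations.components.generators.amazon_sagemaker import (
        SagemakerGenerator,
    )

    generator = SagemakerGenerator(model=endpoint_name, generation_kwargs=INIT_KWARGS)
    generator.warm_up()
    return generator.run(prompt, generation_kwargs=RUN_KWARGS)
```

How the two sets of parameters interact when both are given is not covered here, so check the API reference if you rely on both.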
If your model also needs custom attributes, pass those as a dictionary at initialization time by setting the aws_custom_attributes parameter.
One notable family of models that needs these custom attributes is Llama 2, which must be initialized with {"accept_eula": True}:
generator = SagemakerGenerator(
    model="jumpstart-dft-meta-textgenerationneuron-llama-2-7b",
    aws_custom_attributes={"accept_eula": True},
)
Usage
You need to install the amazon-sagemaker-haystack package to use the SagemakerGenerator:
pip install amazon-sagemaker-haystack
On its own
Basic usage:
from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator
client = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16")
client.warm_up()
response = client.run("Briefly explain what NLP is in one sentence.")
print(response)
>>> {'replies': ["Natural Language Processing (NLP) is a subfield of artificial intelligence and computational linguistics that focuses on the interaction between computers and human languages..."],
 'meta': [{}]}
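Because the response is a plain dictionary, each reply can be paired with its metadata entry. The sketch below assumes the "replies" and "meta" keys listed under the output variables at the top of this page; pair_replies_with_meta is a hypothetical helper, and the sample dictionary is a placeholder rather than real model output.

```python
def pair_replies_with_meta(response: dict) -> list:
    """Return (reply, meta) tuples, using {} when a reply has no metadata entry."""
    replies = response.get("replies", [])
    meta = list(response.get("meta", []))
    # Pad with empty dicts so every reply gets a metadata entry.
    meta += [{}] * (len(replies) - len(meta))
    return list(zip(replies, meta))


# Placeholder response with the documented shape, not real model output.
sample = {"replies": ["first reply", "second reply"], "meta": [{}]}
pairs = pair_replies_with_meta(sample)
```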
In a pipeline
In a RAG pipeline:
from haystack import Document, Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator

# Example document store; replace the sample document with your own data.
docstore = InMemoryDocumentStore()
docstore.write_documents([Document(content="French is the official language of France.")])

template = """
Given the following information, answer the question.
Context:
{% for document in documents %}
{{ document.content }}
{% endfor %}
Question: What's the official language of {{ country }}?
"""

pipe = Pipeline()
pipe.add_component("retriever", InMemoryBM25Retriever(document_store=docstore))
pipe.add_component("prompt_builder", PromptBuilder(template=template))
pipe.add_component("llm", SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16"))
pipe.connect("retriever", "prompt_builder.documents")
pipe.connect("prompt_builder", "llm")

# The retriever needs a query; the prompt builder needs the template variable.
pipe.run({
    "retriever": {
        "query": "What's the official language of France?"
    },
    "prompt_builder": {
        "country": "France"
    }
})
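The value returned by pipe.run() is a dictionary keyed by component name, so the generator's output sits under the "llm" name used above. A minimal sketch of reading it: first_reply is a hypothetical helper, and sample_result is a placeholder with the expected shape, not real output.

```python
def first_reply(result: dict, component: str = "llm"):
    """Return the first reply of the named component, or None if there is none."""
    replies = result.get(component, {}).get("replies", [])
    return replies[0] if replies else None


# Placeholder result shaped like a pipeline output, not real model output.
sample_result = {"llm": {"replies": ["French"], "meta": [{}]}}
answer = first_reply(sample_result)
```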