SagemakerGenerator

This component enables text generation using LLMs deployed on Amazon SageMaker.

Name: SagemakerGenerator
Source: https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_sagemaker
Most common position in a pipeline: After a PromptBuilder
Mandatory input variables:
“prompt”: A string containing the prompt for the LLM
Output variables:
“replies”: A list of strings with all the replies generated by the LLM
“meta”: A list of dictionaries with the metadata associated with each reply, such as token count, finish reason, and so on

SagemakerGenerator lets you use models deployed on AWS SageMaker from within Haystack.

Parameters Overview

SagemakerGenerator needs AWS credentials to work. Set the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables.
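Before constructing the component, it can help to confirm that both variables are visible to the Python process. The following is a minimal sketch; `missing_aws_credentials` is a hypothetical helper written for this example and is not part of the integration.

```python
import os

# The two variables named in this section.
REQUIRED_VARS = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")


def missing_aws_credentials(env=None):
    """Return the names of required AWS variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]


missing = missing_aws_credentials()
if missing:
    print("Missing environment variables:", ", ".join(missing))
```

Running a check like this first gives a clearer message than a failed request to the endpoint would.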

You also need to specify your SageMaker endpoint when you initialize the component. Pass the endpoint name to the model parameter like this:

generator = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16")

Additionally, you can pass any text generation parameters that are valid for your specific model to SagemakerGenerator through the generation_kwargs parameter, either at initialization or when calling the run() method.
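As a sketch of the per-call form, the hypothetical helper below forwards keyword arguments to run() through generation_kwargs. It assumes only what this page describes: run() takes a prompt plus an optional generation_kwargs dictionary and returns a dictionary containing a "replies" list. The parameter name max_new_tokens in the usage note is an example only; which parameters are valid depends on your model.

```python
def generate(generator, prompt, **generation_kwargs):
    """Call generator.run(), passing per-call generation parameters if any were given."""
    if generation_kwargs:
        result = generator.run(prompt, generation_kwargs=generation_kwargs)
    else:
        # No per-call parameters: rely on whatever was set at initialization.
        result = generator.run(prompt)
    return result["replies"]
```

With a generator created as shown above, a call might look like generate(generator, "Briefly explain what NLP is.", max_new_tokens=50).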

If your model also needs custom attributes, pass those as a dictionary at initialization time by setting the aws_custom_attributes parameter.

One notable family of models that needs these custom attributes is Llama2, which must be initialized with {"accept_eula": True}:

generator = SagemakerGenerator(
    model="jumpstart-dft-meta-textgenerationneuron-llama-2-7b",
    aws_custom_attributes={"accept_eula": True},
)
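SageMaker endpoints receive custom attributes as a single string rather than a dictionary. The sketch below shows one plausible way a dictionary like the one above could be flattened into such a string. The semicolon-separated key=value layout and the lowercasing of booleans are assumptions made for illustration, and `format_custom_attributes` is a hypothetical helper, not a function of the integration.

```python
def format_custom_attributes(attributes):
    """Flatten a dict into 'key=value;key=value' (assumed format, for illustration)."""
    parts = []
    for key, value in attributes.items():
        if isinstance(value, bool):
            # Render Python booleans as lowercase 'true' / 'false'.
            value = str(value).lower()
        parts.append(f"{key}={value}")
    return ";".join(parts)


print(format_custom_attributes({"accept_eula": True}))  # accept_eula=true
```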

Usage

To use the SagemakerGenerator, install the amazon-sagemaker-haystack package:

pip install amazon-sagemaker-haystack

On its own

Basic usage:

from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator

client = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16")
client.warm_up()
response = client.run("Briefly explain what NLP is in one sentence.")
print(response)

>>> {'replies': ["Natural Language Processing (NLP) is a subfield of artificial intelligence and computational linguistics that focuses on the interaction between computers and human languages..."],
 'meta': [{}]}
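To work with the result, you can pair each reply with its metadata entry. The following is a minimal sketch that assumes the keys listed under Output variables ("replies" and "meta"); `pair_replies_with_meta` is a hypothetical helper, not part of the integration.

```python
def pair_replies_with_meta(response):
    """Return a list of (reply, metadata) tuples from a generator response."""
    replies = response.get("replies", [])
    meta = list(response.get("meta", []))
    # Pad with empty dicts so every reply has a metadata entry.
    meta += [{}] * (len(replies) - len(meta))
    return list(zip(replies, meta))


sample = {"replies": ["NLP is a subfield of AI."], "meta": [{}]}
for reply, info in pair_replies_with_meta(sample):
    print(reply, info)
```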

In a pipeline

In a RAG pipeline:

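from haystack import Document, Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator

# Example document store; replace the sample document with your own data.
docstore = InMemoryDocumentStore()
docstore.write_documents([Document(content="French is the official language of France.")])

template = """
Given the following information, answer the question.

Context: 
{% for document in documents %}
    {{ document.content }}
{% endfor %}

Question: What's the official language of {{ country }}?
"""
pipe = Pipeline()

pipe.add_component("retriever", InMemoryBM25Retriever(document_store=docstore))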
pipe.add_component("prompt_builder", PromptBuilder(template=template))
pipe.add_component("llm", SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16"))
pipe.connect("retriever", "prompt_builder.documents")
pipe.connect("prompt_builder", "llm")

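# The retriever needs a query to fetch documents; the prompt builder fills in {{ country }}.
pipe.run({
    "retriever": {
        "query": "What's the official language of France?"
    },
    "prompt_builder": {
        "country": "France"
    }
})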
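Pipeline.run() returns a dictionary keyed by component name, so the generator's output sits under the "llm" key used above. The sketch below shows one way to pull out the first reply; `first_reply` is a hypothetical helper, and the sample result only mimics the expected shape rather than showing real model output.

```python
def first_reply(result, component="llm"):
    """Return the first reply produced by the named component, or None if there is none."""
    replies = result.get(component, {}).get("replies", [])
    return replies[0] if replies else None


# Sample data in the shape of a pipeline result, not real model output.
sample_result = {"llm": {"replies": ["French"], "meta": [{}]}}
print(first_reply(sample_result))
```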

Related Links

Check out the API reference in the GitHub repo or in our docs: