GradientGenerator

GradientGenerator enables text generation with LLMs deployed on the Gradient AI platform.
| | |
| --- | --- |
| Name | GradientGenerator |
| Source | https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/gradient |
| Most common position in a pipeline | After a PromptBuilder |
| Mandatory input variables | "prompt": A string containing the prompt for the LLM |
| Output variables | "replies": A list of strings with all the replies generated by the LLM |
GradientGenerator enables text generation using generative models hosted by Gradient AI. You can use either one of the base models provided through the platform or models that you've fine-tuned and are hosting on the platform.
You can currently use the following models, with more becoming available soon. Check out the Gradient documentation for the full list.
- bloom-560m
- llama2-7b-chat
- nous-hermes2
For an example showcasing this component, check out this article and the related 🧑🍳 Cookbook.
Parameters Overview
GradientGenerator needs an access_token and a workspace_id. It also needs either a base_model_slug or a model_adapter_id. You can provide these in one of the following ways:
For the access_token and workspace_id, do one of the following:
- Provide the access_token and workspace_id init parameters.
- Set the GRADIENT_ACCESS_TOKEN and GRADIENT_WORKSPACE_ID environment variables.
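The two credential options above can be sketched as follows. The token and workspace values are placeholders, not real credentials, and the commented-out init call is only illustrative:

```python
import os

# Option 1: set the environment variables before constructing the component;
# GradientGenerator reads them itself. The values here are placeholders.
os.environ["GRADIENT_ACCESS_TOKEN"] = "YOUR_GRADIENT_ACCESS_TOKEN"
os.environ["GRADIENT_WORKSPACE_ID"] = "YOUR_GRADIENT_WORKSPACE_ID"

# Option 2 (sketch, needs the gradient-haystack package and real credentials):
# pass the same values explicitly as init parameters instead.
# generator = GradientGenerator(
#     access_token="YOUR_GRADIENT_ACCESS_TOKEN",
#     workspace_id="YOUR_GRADIENT_WORKSPACE_ID",
#     base_model_slug="llama2-7b-chat",
# )
```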
For the model you would like to use, do one of the following:
- Provide the base_model_slug. Check the available base models in the Gradient documentation.
- If you've deployed a model (fine-tuned or not) on Gradient, provide the model_adapter_id for that model.
Usage
You need to install the gradient-haystack package to use the GradientGenerator:

```shell
pip install gradient-haystack
```
On its own
Basic usage (with a base model). You can replace the base_model_slug with a model_adapter_id to use your own deployed models in your Gradient workspace:
```python
import os

from haystack_integrations.components.generators.gradient import GradientGenerator

os.environ["GRADIENT_ACCESS_TOKEN"] = "YOUR_GRADIENT_ACCESS_TOKEN"
os.environ["GRADIENT_WORKSPACE_ID"] = "YOUR_GRADIENT_WORKSPACE_ID"

generator = GradientGenerator(base_model_slug="llama2-7b-chat",
                              max_generated_token_count=350)
generator.warm_up()
generator.run(prompt="What is the meaning of life?")
```
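Per the output variables table above, run() returns a dictionary with a "replies" list of generated strings. A minimal sketch of consuming that result; the dictionary below is illustrative, not an actual model reply:

```python
# Illustrative shape of GradientGenerator.run()'s return value; the reply
# text is made up for the example, not produced by a real model.
result = {"replies": ["Generated answer text from the model."]}

# "replies" is a list because a generator may return multiple completions.
first_reply = result["replies"][0]
print(first_reply)
```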
In a pipeline
Here's an example of this generator in a RAG pipeline. This pipeline uses the GradientTextEmbedder to embed the query; the matching GradientDocumentEmbedder would be used when indexing the documents. You can replace these with any other embedders. The example assumes that you have an InMemoryDocumentStore that already contains embedded Documents:
```python
import os

from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack_integrations.components.embedders.gradient import GradientTextEmbedder
from haystack_integrations.components.generators.gradient import GradientGenerator

os.environ["GRADIENT_ACCESS_TOKEN"] = "YOUR_GRADIENT_ACCESS_TOKEN"
os.environ["GRADIENT_WORKSPACE_ID"] = "YOUR_GRADIENT_WORKSPACE_ID"

document_store = InMemoryDocumentStore()

prompt = """Answer the query, based on the content in the documents.

Documents:
{% for doc in documents %}
{{ doc.content }}
{% endfor %}

Query: {{ query }}
"""

text_embedder = GradientTextEmbedder()
retriever = InMemoryEmbeddingRetriever(document_store=document_store)
prompt_builder = PromptBuilder(template=prompt)
generator = GradientGenerator(model_adapter_id="YOUR_MODEL_ADAPTER_ID",
                              max_generated_token_count=350)

rag_pipeline = Pipeline()
rag_pipeline.add_component(instance=text_embedder, name="text_embedder")
rag_pipeline.add_component(instance=retriever, name="retriever")
rag_pipeline.add_component(instance=prompt_builder, name="prompt_builder")
rag_pipeline.add_component(instance=generator, name="generator")

rag_pipeline.connect("text_embedder", "retriever")
rag_pipeline.connect("retriever.documents", "prompt_builder.documents")
rag_pipeline.connect("prompt_builder", "generator")

question = "What are the steps for creating a custom component?"
result = rag_pipeline.run(data={"text_embedder": {"text": question},
                                "prompt_builder": {"query": question}})
```
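When run inside a pipeline, each component's output is namespaced under the name it was registered with, so the generated text ends up under result["generator"]["replies"]. A sketch with an illustrative result dictionary (not real pipeline output):

```python
# Illustrative pipeline result: outputs are keyed first by component name
# ("generator" above), then by output name ("replies"). The text is made up.
result = {"generator": {"replies": ["To create a custom component, ..."]}}

# Pull the first reply out of the generator's output.
answer = result["generator"]["replies"][0]
print(answer)
```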