文档API 参考📓 教程🧑‍🍳 食谱🤝 集成💜 Discord🎨 Studio
API 参考

Amazon Sagemaker

Amazon Sagemaker 集成,用于 Haystack

模块 haystack_integrations.components.generators.amazon_sagemaker.sagemaker

SagemakerGenerator

通过 Amazon Sagemaker 实现文本生成。

SagemakerGenerator 支持托管在 SageMaker 推理端点的超大语言模型 (LLMs)。有关如何将模型部署到 SageMaker 的指导,请参阅 SageMaker JumpStart 基础模型文档

使用示例

# Make sure your AWS credentials are set up correctly. You can use environment variables or a shared credentials
# file. Then you can use the generator as follows:
from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator

generator = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-bf16")
response = generator.run("What's Natural Language Processing? Be brief.")
print(response)
>>> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on
>>> the interaction between computers and human language. It involves enabling computers to understand, interpret,
>>> and respond to natural human language in a way that is both meaningful and useful.'], 'meta': [{}]}

SagemakerGenerator.__init__

def __init__(
        model: str,
        aws_access_key_id: Optional[Secret] = Secret.from_env_var(
            ["AWS_ACCESS_KEY_ID"], strict=False),
        aws_secret_access_key: Optional[Secret] = Secret.
    from_env_var(  # noqa: B008
        ["AWS_SECRET_ACCESS_KEY"], strict=False),
        aws_session_token: Optional[Secret] = Secret.from_env_var(
            ["AWS_SESSION_TOKEN"], strict=False),
        aws_region_name: Optional[Secret] = Secret.from_env_var(
            ["AWS_DEFAULT_REGION"], strict=False),
        aws_profile_name: Optional[Secret] = Secret.from_env_var(
            ["AWS_PROFILE"], strict=False),
        aws_custom_attributes: Optional[Dict[str, Any]] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None)

初始化 SageMaker 会话。

参数:

  • aws_access_key_id: AWS 访问密钥 ID 的Secret
  • aws_secret_access_key: AWS 密钥访问的Secret
  • aws_session_token: AWS 会话令牌的Secret
  • aws_region_name: AWS 区域名称的Secret。如果未提供,将使用默认区域。
  • aws_profile_name: AWS 配置文件名称的Secret。如果未提供,将使用默认配置文件。
  • model: SageMaker 模型端点的名称。
  • aws_custom_attributes: 要传递给 SageMaker 的自定义属性,例如{"accept_eula": True}(对于 Llama-2 模型)。
  • generation_kwargs: 用于文本生成的其他关键字参数。支持的参数列表请参见您的模型文档页面,例如 HuggingFace 模型的文档: https://hugging-face.cn/blog/sagemaker-huggingface-llm#4-run-inference-and-chat-with-our-model

具体来说,Llama-2 模型支持以下推理载荷参数:

  • max_new_tokens: 模型生成文本,直到输出长度(不包括输入上下文长度)达到max_new_tokens。如果指定,必须是正整数。
  • temperature: 控制输出的随机性。较高的温度会产生低概率词语的输出序列,而较低的温度会产生高概率词语的输出序列。如果temperature=0,则会产生贪婪解码。如果指定,必须是正浮点数。
  • top_p: 在文本生成的每一步,从累积概率最小的词语集合中进行采样top_p。如果指定,必须是介于 0 和 1 之间的浮点数。
  • return_full_text: 如果True,输入文本将包含在生成的输出文本中。如果指定,必须是布尔值。默认值为False.

SagemakerGenerator.to_dict

def to_dict() -> Dict[str, Any]

将组件序列化为字典。

返回值:

包含序列化数据的字典。

SagemakerGenerator.from_dict

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SagemakerGenerator"

从字典反序列化组件。

参数:

  • data: 要反序列化的字典。

返回值:

反序列化后的组件。

SagemakerGenerator.run

@component.output_types(replies=List[str], meta=List[Dict[str, Any]])
def run(
    prompt: str,
    generation_kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Union[List[str], List[Dict[str, Any]]]]

根据提供的提示和生成参数调用文本生成推理。

参数:

  • prompt: 用于文本生成的字符串提示。
  • generation_kwargs:用于文本生成的附加关键字参数。这些参数可能会覆盖在streaming_callback

引发:

  • ValueError: 如果模型响应类型不是字典列表或单个字典。
  • SagemakerNotReadyError: 如果 SageMaker 模型尚未准备好接受请求。
  • SagemakerInferenceError: 如果 SageMaker 推理返回错误。

返回值:

包含以下键的字典

  • replies: 包含生成响应的字符串列表
  • meta: 包含每个响应元数据的字典列表。