文档API 参考📓 教程🧑‍🍳 食谱🤝 集成💜 Discord🎨 Studio
API 参考

Routers (路由器)

路由器是一组组件,用于将查询或文档路由到最能处理它们的其他组件。

模块 conditional_router

NoRouteSelectedException

在 ConditionalRouter 中未选择路由时引发的异常。

RouteConditionException

在 ConditionalRouter 中解析或评估条件表达式时出错时引发的异常。

ConditionalRouter

根据特定条件路由数据。

您在名为routes 的字典列表中定义这些条件。此列表中的每个字典代表一个路由。每个路由都有这四个元素

  • condition:一个 Jinja2 字符串表达式,用于确定是否选择路由。
  • output:一个 Jinja2 表达式,定义路由的输出值。
  • output_type:输出数据的类型(例如,str, list[int]).
  • output_name:您想用来发布output 的名称。此名称用于将路由器连接到管道中的其他组件。

使用示例

from haystack.components.routers import ConditionalRouter

routes = [
    {
        "condition": "{{streams|length > 2}}",
        "output": "{{streams}}",
        "output_name": "enough_streams",
        "output_type": list[int],
    },
    {
        "condition": "{{streams|length <= 2}}",
        "output": "{{streams}}",
        "output_name": "insufficient_streams",
        "output_type": list[int],
    },
]
router = ConditionalRouter(routes)
# When 'streams' has more than 2 items, 'enough_streams' output will activate, emitting the list [1, 2, 3]
kwargs = {"streams": [1, 2, 3], "query": "Haystack"}
result = router.run(**kwargs)
assert result == {"enough_streams": [1, 2, 3]}

在此示例中,我们配置了两个路由。如果流计数超过两个,则第一个路由将“streams”值发送到“enough_streams”。如果流数小于或等于两个,则第二个路由将“streams”导向“insufficient_streams”。

在管道设置中,路由器使用输出名称连接到其他组件。例如,“enough_streams”可能连接到处理流的组件,而“insufficient_streams”可能连接到获取更多流的组件。

这是一个使用ConditionalRouter 的管道,并根据获取的流数量将获取的ByteStreams 路由到不同的组件。

from haystack import Pipeline
from haystack.dataclasses import ByteStream
from haystack.components.routers import ConditionalRouter

routes = [
    {
        "condition": "{{streams|length > 2}}",
        "output": "{{streams}}",
        "output_name": "enough_streams",
        "output_type": list[ByteStream],
    },
    {
        "condition": "{{streams|length <= 2}}",
        "output": "{{streams}}",
        "output_name": "insufficient_streams",
        "output_type": list[ByteStream],
    },
]

pipe = Pipeline()
pipe.add_component("router", router)
...
pipe.connect("router.enough_streams", "some_component_a.streams")
pipe.connect("router.insufficient_streams", "some_component_b.streams_or_some_other_input")
...

ConditionalRouter.__init__

def __init__(routes: list[Route],
             custom_filters: Optional[dict[str, Callable]] = None,
             unsafe: bool = False,
             validate_output_type: bool = False,
             optional_variables: Optional[list[str]] = None)

使用详细说明路由条件的路由列表初始化ConditionalRouter

参数:

  • routes:一个字典列表,每个字典定义一个路由。每个路由都有这四个元素
  • condition:一个 Jinja2 字符串表达式,用于确定是否选择路由。
  • output:一个 Jinja2 表达式,定义路由的输出值。
  • output_type:输出数据的类型(例如,str, list[int]).
  • output_name:您想用来发布output 的名称。此名称用于将路由器连接到管道中的其他组件。
  • custom_filters:用于条件表达式中的自定义 Jinja2 过滤器字典。例如,传递{"my_filter": my_filter_fcn},其中
  • my_filter 是自定义过滤器的名称。
  • my_filter_fcn 是一个可调用对象,它接受my_var:str 并返回my_var[:3]. {{ my_var|my_filter }} 可以在路由条件表达式内部使用"condition": "{{ my_var|my_filter == 'foo' }}".
  • unsafe:在 Jinja 模板中启用任意代码执行。这只应在您信任模板来源时使用,因为它可能导致远程代码执行。
  • validate_output_type:启用路由输出的验证。如果路由输出与声明的类型不匹配,则会引发 ValueError。
  • optional_variables:在路由条件和输出中可选的变量名称列表。如果在运行时未提供这些变量,它们将被设置为None。这允许您编写可以优雅地处理缺失输入而不会引发错误的路由。

在管道中使用默认回退路由的示例

from haystack import Pipeline
from haystack.components.routers import ConditionalRouter

routes = [
    {
        "condition": '{{ path == "rag" }}',
        "output": "{{ question }}",
        "output_name": "rag_route",
        "output_type": str
    },
    {
        "condition": "{{ True }}",  # fallback route
        "output": "{{ question }}",
        "output_name": "default_route",
        "output_type": str
    }
]

router = ConditionalRouter(routes, optional_variables=["path"])
pipe = Pipeline()
pipe.add_component("router", router)

# When 'path' is provided in the pipeline:
result = pipe.run(data={"router": {"question": "What?", "path": "rag"}})
assert result["router"] == {"rag_route": "What?"}

# When 'path' is not provided, fallback route is taken:
result = pipe.run(data={"router": {"question": "What?"}})
assert result["router"] == {"default_route": "What?"}

此模式在以下情况特别有用

  • 您希望在某些输入缺失时提供默认/回退行为
  • 某些变量仅用于特定的路由条件
  • 您正在构建灵活的管道,其中并非所有输入都保证存在

ConditionalRouter.to_dict

def to_dict() -> dict[str, Any]

将组件序列化为字典。

返回值:

包含序列化数据的字典。

ConditionalRouter.from_dict

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "ConditionalRouter"

从字典反序列化组件。

参数:

  • data: 要反序列化的字典。

返回值:

反序列化后的组件。

ConditionalRouter.run

def run(**kwargs)

执行路由逻辑。

通过按列出的顺序评估每个路由的指定布尔条件表达式来执行路由逻辑。该方法将数据流导向第一个condition 为 True 的路由中指定的输出。

参数:

  • kwargs:路由中condition 中使用的所有变量。当组件在管道中使用时,这些变量从前一个组件的输出中传递。

引发:

  • NoRouteSelectedException:如果路由中没有condition 为 `True`。
  • RouteConditionException:如果在解析或评估路由中的condition 表达式时出错。
  • ValueError:如果启用了类型验证且路由类型与实际值类型不匹配。

返回值:

一个字典,其中键是所选路由的output_name,值是所选路由的output

模块 document_length_router

DocumentLengthRouter

根据content 字段的长度对文档进行分类,并将其路由到适当的输出。

DocumentLengthRouter 的常见用例是处理从包含非文本内容(如扫描页面或图像)的 PDF 中获取的文档。此组件可以检测空文档或内容较少的文档,并将其路由到执行 OCR、生成字幕或计算图像嵌入的组件。

使用示例

from haystack.components.routers import DocumentLengthRouter
from haystack.dataclasses import Document

docs = [
    Document(content="Short"),
    Document(content="Long document "*20),
]

router = DocumentLengthRouter(threshold=10)

result = router.run(documents=docs)
print(result)

# {
#     "short_documents": [Document(content="Short", ...)],
#     "long_documents": [Document(content="Long document ...", ...)],
# }

DocumentLengthRouter.__init__

def __init__(*, threshold: int = 10) -> None

初始化 DocumentLengthRouter 组件。

参数:

  • threshold:文档中字符数的阈值content 字段。当content 为 None 或其字符数小于或等于阈值时,文档将被路由到short_documents 输出。否则,它们将被路由到long_documents 输出。要仅将内容为 None 的文档路由到short_documents,请将阈值设置为负数。

DocumentLengthRouter.run

@component.output_types(short_documents=list[Document],
                        long_documents=list[Document])
def run(documents: list[Document]) -> dict[str, list[Document]]

根据content 字段。

参数:

  • documents 长度将输入文档分类到组中:要分类的文档列表。

返回值:

包含以下键的字典

  • short_documents:一个文档列表,其中content 为 None,或长度小于或等于阈值。content 的长度小于或等于阈值。
  • long_documents:一个文档列表,其中content 的长度大于阈值。

模块 document_type_router

DocumentTypeRouter

按 MIME 类型路由文档。

DocumentTypeRouter 用于根据文档的 MIME 类型在管道中动态路由文档。它支持精确的 MIME 类型匹配和正则表达式模式。

MIME 类型可以直接从文档元数据中提取,或使用标准或用户提供的 MIME 类型映射从文件路径中推断。

使用示例

from haystack.components.routers import DocumentTypeRouter
from haystack.dataclasses import Document

docs = [
    Document(content="Example text", meta={"file_path": "example.txt"}),
    Document(content="Another document", meta={"mime_type": "application/pdf"}),
    Document(content="Unknown type")
]

router = DocumentTypeRouter(
    mime_type_meta_field="mime_type",
    file_path_meta_field="file_path",
    mime_types=["text/plain", "application/pdf"]
)

result = router.run(documents=docs)
print(result)

预期输出

{
    "text/plain": [Document(...)],
    "application/pdf": [Document(...)],
    "unclassified": [Document(...)]
}

DocumentTypeRouter.__init__

def __init__(*,
             mime_types: list[str],
             mime_type_meta_field: Optional[str] = None,
             file_path_meta_field: Optional[str] = None,
             additional_mimetypes: Optional[dict[str, str]] = None) -> None

初始化 DocumentTypeRouter 组件。

参数:

  • mime_types:一个 MIME 类型或正则表达式模式列表,用于分类输入文档。(例如["text/plain", "audio/x-wav", "image/jpeg"]).
  • mime_type_meta_field:可选的元数据字段名称,用于存储 MIME 类型。
  • file_path_meta_field:可选的元数据字段名称,用于存储文件路径。如果mime_type_meta_field 未提供或文档中缺失,则用于推断 MIME 类型。
  • additional_mimetypes:可选的字典,将 MIME 类型映射到文件扩展名,以增强或覆盖标准mimetypes 模块。在处理不常见或自定义文件类型时很有用。例如{"application/vnd.custom-type": ".custom"}.

引发:

  • ValueError: 如果mime_types 为空,或者mime_type_meta_fieldfile_path_meta_field 都未提供。

DocumentTypeRouter.run

def run(documents: list[Document]) -> dict[str, list[Document]]

根据其 MIME 类型将输入文档分类到组中。

MIME 类型可以直接在文档元数据中提供,或使用标准 Python 从文件路径中派生mimetypes 模块和自定义映射。

参数:

  • documents 长度将输入文档分类到组中:要分类的文档列表。

返回值:

一个字典,其中键是 MIME 类型(或"unclassified"),值是文档列表。

模块 file_type_router

FileTypeRouter

按 MIME 类型对文件或字节流进行分类,有助于基于上下文的路由。

FileTypeRouter 支持精确的 MIME 类型匹配和正则表达式模式。

对于文件路径,MIME 类型来自扩展名,而字节流使用元数据。您可以在mime_types 参数中使用正则表达式模式来设置广泛的类别(例如“audio/*”或“text/*”)或特定类型。没有正则表达式模式的 MIME 类型被视为精确匹配。

使用示例

from haystack.components.routers import FileTypeRouter
from pathlib import Path

# For exact MIME type matching
router = FileTypeRouter(mime_types=["text/plain", "application/pdf"])

# For flexible matching using regex, to handle all audio types
router_with_regex = FileTypeRouter(mime_types=[r"audio/.*", r"text/plain"])

sources = [Path("file.txt"), Path("document.pdf"), Path("song.mp3")]
print(router.run(sources=sources))
print(router_with_regex.run(sources=sources))

# Expected output:
# {'text/plain': [
#   PosixPath('file.txt')], 'application/pdf': [PosixPath('document.pdf')], 'unclassified': [PosixPath('song.mp3')
# ]}
# {'audio/.*': [
#   PosixPath('song.mp3')], 'text/plain': [PosixPath('file.txt')], 'unclassified': [PosixPath('document.pdf')
# ]}

FileTypeRouter.__init__

def __init__(mime_types: list[str],
             additional_mimetypes: Optional[dict[str, str]] = None,
             raise_on_failure: bool = False)

初始化 FileTypeRouter 组件。

参数:

  • mime_types:一个 MIME 类型或正则表达式模式列表,用于分类输入文件或字节流。(例如["text/plain", "audio/x-wav", "image/jpeg"]).
  • additional_mimetypes:一个字典,包含要添加到 mimetypes 包的 MIME 类型,以防止不支持或非原生包未分类。(例如{"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx"}).
  • raise_on_failure:如果为 True,则在文件路径不存在时引发 FileNotFoundError。如果为 False(默认),则仅在文件路径不存在时发出警告。

FileTypeRouter.to_dict

def to_dict() -> dict[str, Any]

将组件序列化为字典。

返回值:

包含序列化数据的字典。

FileTypeRouter.from_dict

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "FileTypeRouter"

从字典反序列化组件。

参数:

  • data: 要反序列化的字典。

返回值:

反序列化后的组件。

FileTypeRouter.run

def run(
    sources: list[Union[str, Path, ByteStream]],
    meta: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None
) -> dict[str, list[Union[ByteStream, Path]]]

根据 MIME 类型对文件或字节流进行分类。

参数:

  • sources:要分类的文件路径或字节流列表。
  • meta:可选的元数据,用于附加到源。提供时,源内部转换为 ByteStream 对象并添加元数据。此值可以是字典列表或单个字典。如果它是单个字典,其内容将添加到所有 ByteStream 对象的元数据中。如果它是列表,其长度必须与源的数量匹配,因为它们是打包在一起的。

返回值:

一个字典,其中键是 MIME 类型,值是数据源列表。可能返回两个额外的键"unclassified" 当源的 MIME 类型不匹配任何模式时,以及"failed" 当源无法处理时(例如,不存在的文件路径)。

模块 llm_messages_router

LLMMessagesRouter

使用生成式语言模型对聊天消息进行分类,并将其路由到不同的连接。

This component can be used with general-purpose LLMs and with specialized LLMs for moderation like Llama Guard.

### Usage example
```python
from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.components.routers.llm_messages_router import LLMMessagesRouter
from haystack.dataclasses import ChatMessage

# initialize a Chat Generator with a generative model for moderation
chat_generator = HuggingFaceAPIChatGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "meta-llama/Llama-Guard-4-12B", "provider": "groq"},
)

router = LLMMessagesRouter(chat_generator=chat_generator,
                            output_names=["unsafe", "safe"],
                            output_patterns=["unsafe", "safe"])


print(router.run([ChatMessage.from_user("How to rob a bank?")]))

# {
#     'chat_generator_text': 'unsafe

S2', # 'unsafe': [ # ChatMessage( # _role=<ChatRole.USER: 'user'>, # _content=[TextContent(text='How to rob a bank?')], # _name=None, # _meta={} # ) # ] # } ```

LLMMessagesRouter.__init__

def __init__(chat_generator: ChatGenerator,
             output_names: list[str],
             output_patterns: list[str],
             system_prompt: Optional[str] = None)

初始化 LLMMessagesRouter 组件。

参数:

  • chat_generator:表示 LLM 的 ChatGenerator 实例。
  • output_names:输出连接名称列表。这些可用于将路由器连接到其他组件。
  • output_patterns:要与 LLM 输出匹配的正则表达式列表。每个模式对应一个输出名称。模式按顺序评估。使用审核模型时,请参阅模型卡以了解预期输出。
  • system_prompt:用于自定义 LLM 行为的可选系统提示。对于审核模型,请参阅模型卡以了解支持的自定义选项。

引发:

  • ValueError:如果 output_names 和 output_patterns 不是相同长度的非空列表。

LLMMessagesRouter.warm_up

def warm_up()

预热底层 LLM。

LLMMessagesRouter.run

def run(messages: list[ChatMessage]
        ) -> dict[str, Union[str, list[ChatMessage]]]

根据 LLM 输出对消息进行分类,并将其路由到适当的输出连接。

参数:

  • messages:要路由的 ChatMessages 列表。仅支持用户和助手消息。

引发:

  • ValueError:如果 messages 是空列表或包含具有不支持角色的消息。
  • RuntimeError:如果组件未预热且 ChatGenerator 具有 warm_up 方法。

返回值:

包含以下键的字典

  • "chat_generator_text":LLM 的文本输出,对调试很有用。
  • "output_names":每个都包含与相应模式匹配的消息列表。
  • "unmatched":不匹配任何输出模式的消息。

LLMMessagesRouter.to_dict

def to_dict() -> dict[str, Any]

将此组件序列化为字典。

返回值:

序列化后的组件(字典格式)。

LLMMessagesRouter.from_dict

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "LLMMessagesRouter"

从字典反序列化此组件。

参数:

  • data:此组件的字典表示。

返回值:

反序列化的组件实例。

模块 metadata_router

MetadataRouter

根据其元数据字段将文档或字节流路由到不同的连接。

init 方法中指定路由规则。如果文档或字节流不匹配任何规则,它将被路由到名为“unmatched”的连接。

使用示例

按元数据路由文档

from haystack import Document
from haystack.components.routers import MetadataRouter

docs = [Document(content="Paris is the capital of France.", meta={"language": "en"}),
        Document(content="Berlin ist die Haupststadt von Deutschland.", meta={"language": "de"})]

router = MetadataRouter(rules={"en": {"field": "meta.language", "operator": "==", "value": "en"}})

print(router.run(documents=docs))
# {'en': [Document(id=..., content: 'Paris is the capital of France.', meta: {'language': 'en'})],
# 'unmatched': [Document(id=..., content: 'Berlin ist die Haupststadt von Deutschland.', meta: {'language': 'de'})]}

按元数据路由字节流

from haystack.dataclasses import ByteStream
from haystack.components.routers import MetadataRouter

streams = [
    ByteStream.from_string("Hello world", meta={"language": "en"}),
    ByteStream.from_string("Bonjour le monde", meta={"language": "fr"})
]

router = MetadataRouter(
    rules={"english": {"field": "meta.language", "operator": "==", "value": "en"}},
    output_type=list[ByteStream]
)

result = router.run(documents=streams)
# {'english': [ByteStream(...)], 'unmatched': [ByteStream(...)]}

MetadataRouter.__init__

def __init__(rules: dict[str, dict],
             output_type: type = list[Document]) -> None

初始化 MetadataRouter 组件。

参数:

  • rules:一个字典,定义如何根据文档或字节流的元数据将其路由到输出连接。键是输出连接名称,值是 Haystack 中 过滤表达式 的字典。例如
{
"edge_1": {
    "operator": "AND",
    "conditions": [
        {"field": "meta.created_at", "operator": ">=", "value": "2023-01-01"},
        {"field": "meta.created_at", "operator": "<", "value": "2023-04-01"},
    ],
},
"edge_2": {
    "operator": "AND",
    "conditions": [
        {"field": "meta.created_at", "operator": ">=", "value": "2023-04-01"},
        {"field": "meta.created_at", "operator": "<", "value": "2023-07-01"},
    ],
},
"edge_3": {
    "operator": "AND",
    "conditions": [
        {"field": "meta.created_at", "operator": ">=", "value": "2023-07-01"},
        {"field": "meta.created_at", "operator": "<", "value": "2023-10-01"},
    ],
},
"edge_4": {
    "operator": "AND",
    "conditions": [
        {"field": "meta.created_at", "operator": ">=", "value": "2023-10-01"},
        {"field": "meta.created_at", "operator": "<", "value": "2024-01-01"},
    ],
},
}

:param output_type: 生成的输出类型。可以指定文档或字节流列表。

MetadataRouter.run

def run(documents: Union[list[Document], list[ByteStream]])

根据其元数据字段将文档或字节流路由到不同的连接。

如果文档或字节流不匹配任何规则,它将被路由到名为“unmatched”的连接。

参数:

  • documents: Document 对象列表。DocumentByteStream 对象,根据其元数据进行路由。

返回值:

一个字典,其中键是输出连接的名称(包括"unmatched"),值是与相应规则匹配的DocumentByteStream 对象列表。

MetadataRouter.to_dict

def to_dict() -> dict[str, Any]

将此组件序列化为字典。

返回值:

序列化后的组件(字典格式)。

MetadataRouter.from_dict

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "MetadataRouter"

从字典反序列化此组件。

参数:

  • data:此组件的字典表示。

返回值:

反序列化的组件实例。

模块 text_language_router

TextLanguageRouter

根据文本字符串的语言将其路由到不同的输出连接。

在初始化期间提供语言列表。如果文档的文本不匹配任何指定的语言,则元数据值将设置为“unmatched”。要根据文档的语言路由文档,请使用 DocumentLanguageClassifier 组件,然后使用 MetaDataRouter。

使用示例

from haystack import Pipeline, Document
from haystack.components.routers import TextLanguageRouter
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever

document_store = InMemoryDocumentStore()
document_store.write_documents([Document(content="Elvis Presley was an American singer and actor.")])

p = Pipeline()
p.add_component(instance=TextLanguageRouter(languages=["en"]), name="text_language_router")
p.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="retriever")
p.connect("text_language_router.en", "retriever.query")

result = p.run({"text_language_router": {"text": "Who was Elvis Presley?"}})
assert result["retriever"]["documents"][0].content == "Elvis Presley was an American singer and actor."

result = p.run({"text_language_router": {"text": "ένα ελληνικό κείμενο"}})
assert result["text_language_router"]["unmatched"] == "ένα ελληνικό κείμενο"

TextLanguageRouter.__init__

def __init__(languages: Optional[list[str]] = None)

初始化 TextLanguageRouter 组件。

参数:

  • languages:ISO 语言代码列表。请参阅 langdetect 文档中支持的语言。如果未指定,则默认为 ["en"]。

TextLanguageRouter.run

def run(text: str) -> dict[str, str]

根据文本字符串的语言将其路由到不同的输出连接。

如果文档的文本不匹配任何指定的语言,则元数据值将设置为“unmatched”。

参数:

  • text:要路由的文本字符串。

引发:

  • TypeError: 如果输入不是字符串。

返回值:

一个字典,其中键是语言(或"unmatched"),值是文本。

模块 transformers_text_router

TransformersTextRouter

根据类别标签将文本字符串路由到不同的连接。

标签特定于每个模型,可以在 Hugging Face 上的描述中找到。

使用示例

from haystack.core.pipeline import Pipeline
from haystack.components.routers import TransformersTextRouter
from haystack.components.builders import PromptBuilder
from haystack.components.generators import HuggingFaceLocalGenerator

p = Pipeline()
p.add_component(
    instance=TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection"),
    name="text_router"
)
p.add_component(
    instance=PromptBuilder(template="Answer the question: {{query}}\nAnswer:"),
    name="english_prompt_builder"
)
p.add_component(
    instance=PromptBuilder(template="Beantworte die Frage: {{query}}\nAntwort:"),
    name="german_prompt_builder"
)

p.add_component(
    instance=HuggingFaceLocalGenerator(model="DiscoResearch/Llama3-DiscoLeo-Instruct-8B-v0.1"),
    name="german_llm"
)
p.add_component(
    instance=HuggingFaceLocalGenerator(model="microsoft/Phi-3-mini-4k-instruct"),
    name="english_llm"
)

p.connect("text_router.en", "english_prompt_builder.query")
p.connect("text_router.de", "german_prompt_builder.query")
p.connect("english_prompt_builder.prompt", "english_llm.prompt")
p.connect("german_prompt_builder.prompt", "german_llm.prompt")

# English Example
print(p.run({"text_router": {"text": "What is the capital of Germany?"}}))

# German Example
print(p.run({"text_router": {"text": "Was ist die Hauptstadt von Deutschland?"}}))

TransformersTextRouter.__init__

def __init__(model: str,
             labels: Optional[list[str]] = None,
             device: Optional[ComponentDevice] = None,
             token: Optional[Secret] = Secret.from_env_var(
                 ["HF_API_TOKEN", "HF_TOKEN"], strict=False),
             huggingface_pipeline_kwargs: Optional[dict[str, Any]] = None)

初始化 TransformersTextRouter 组件。

参数:

  • model:用于文本分类的 Hugging Face 模型的名称或路径。
  • labels:标签列表。如果未提供,组件将使用transformers.AutoConfig.from_pretrained.
  • device:加载模型的设备。如果None,自动选择默认设备。如果设备或设备映射在huggingface_pipeline_kwargs 中指定了设备/设备映射,它将覆盖此参数。
  • token:用于从 Hugging Face 下载私有模型的 API 令牌。如果True,则使用HF_API_TOKENHF_TOKEN 环境变量。要生成这些令牌,请运行transformers-cli login.
  • huggingface_pipeline_kwargs:用于初始化 Hugging Face 文本分类管道的关键字参数字典。

TransformersTextRouter.warm_up

def warm_up()

Initializes the component.

TransformersTextRouter.to_dict

def to_dict() -> dict[str, Any]

将组件序列化为字典。

返回值:

包含序列化数据的字典。

TransformersTextRouter.from_dict

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "TransformersTextRouter"

从字典反序列化组件。

参数:

  • data: 要反序列化的字典。

返回值:

反序列化后的组件。

TransformersTextRouter.run

def run(text: str) -> dict[str, str]

根据类别标签将文本字符串路由到不同的连接。

参数:

  • text:要路由的文本字符串。

引发:

  • TypeError:如果输入不是字符串。
  • RuntimeError:如果由于之前未调用 warm_up() 而未加载管道。

返回值:

一个字典,其中键是标签,值是文本。

模块 zero_shot_text_router

TransformersZeroShotTextRouter

根据类别标签将文本字符串路由到不同的连接。

初始化组件时指定分类的标签集。

使用示例

from haystack import Document
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.core.pipeline import Pipeline
from haystack.components.routers import TransformersZeroShotTextRouter
from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
from haystack.components.retrievers import InMemoryEmbeddingRetriever

document_store = InMemoryDocumentStore()
doc_embedder = SentenceTransformersDocumentEmbedder(model="intfloat/e5-base-v2")
doc_embedder.warm_up()
docs = [
    Document(
        content="Germany, officially the Federal Republic of Germany, is a country in the western region of "
        "Central Europe. The nation's capital and most populous city is Berlin and its main financial centre "
        "is Frankfurt; the largest urban area is the Ruhr."
    ),
    Document(
        content="France, officially the French Republic, is a country located primarily in Western Europe. "
        "France is a unitary semi-presidential republic with its capital in Paris, the country's largest city "
        "and main cultural and commercial centre; other major urban areas include Marseille, Lyon, Toulouse, "
        "Lille, Bordeaux, Strasbourg, Nantes and Nice."
    )
]
docs_with_embeddings = doc_embedder.run(docs)
document_store.write_documents(docs_with_embeddings["documents"])

p = Pipeline()
p.add_component(instance=TransformersZeroShotTextRouter(labels=["passage", "query"]), name="text_router")
p.add_component(
    instance=SentenceTransformersTextEmbedder(model="intfloat/e5-base-v2", prefix="passage: "),
    name="passage_embedder"
)
p.add_component(
    instance=SentenceTransformersTextEmbedder(model="intfloat/e5-base-v2", prefix="query: "),
    name="query_embedder"
)
p.add_component(
    instance=InMemoryEmbeddingRetriever(document_store=document_store),
    name="query_retriever"
)
p.add_component(
    instance=InMemoryEmbeddingRetriever(document_store=document_store),
    name="passage_retriever"
)

p.connect("text_router.passage", "passage_embedder.text")
p.connect("passage_embedder.embedding", "passage_retriever.query_embedding")
p.connect("text_router.query", "query_embedder.text")
p.connect("query_embedder.embedding", "query_retriever.query_embedding")

# Query Example
p.run({"text_router": {"text": "What is the capital of Germany?"}})

# Passage Example
p.run({
    "text_router":{
        "text": "The United Kingdom of Great Britain and Northern Ireland, commonly known as the "            "United Kingdom (UK) or Britain, is a country in Northwestern Europe, off the north-western coast of "            "the continental mainland."
    }
})

TransformersZeroShotTextRouter.__init__

def __init__(labels: list[str],
             multi_label: bool = False,
             model: str = "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33",
             device: Optional[ComponentDevice] = None,
             token: Optional[Secret] = Secret.from_env_var(
                 ["HF_API_TOKEN", "HF_TOKEN"], strict=False),
             huggingface_pipeline_kwargs: Optional[dict[str, Any]] = None)

初始化 TransformersZeroShotTextRouter 组件。

参数:

  • labels:用于分类的标签集。可以是单个标签、逗号分隔的标签字符串或标签列表。
  • multi_label:指示是否可以有多个真实标签。如果False,则标签分数归一化,使其总和为每个序列的 1。如果True,则标签被视为独立,并且通过对蕴含分数与矛盾分数进行 softmax 来对每个候选的概率进行归一化。
  • model:用于零样本文本分类的 Hugging Face 模型的名称或路径。
  • device:加载模型的设备。如果None,自动选择默认设备。如果设备或设备映射在huggingface_pipeline_kwargs 中指定了设备/设备映射,它将覆盖此参数。
  • token:用于从 Hugging Face 下载私有模型的 API 令牌。如果True,则使用HF_API_TOKENHF_TOKEN 环境变量。要生成这些令牌,请运行transformers-cli login.
  • huggingface_pipeline_kwargs:用于初始化 Hugging Face 零样本文本分类的关键字参数字典。

TransformersZeroShotTextRouter.warm_up

def warm_up()

Initializes the component.

TransformersZeroShotTextRouter.to_dict

def to_dict() -> dict[str, Any]

将组件序列化为字典。

返回值:

包含序列化数据的字典。

TransformersZeroShotTextRouter.from_dict

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "TransformersZeroShotTextRouter"

从字典反序列化组件。

参数:

  • data: 要反序列化的字典。

返回值:

反序列化后的组件。

TransformersZeroShotTextRouter.run

def run(text: str) -> dict[str, str]

根据类别标签将文本字符串路由到不同的连接。

参数:

  • text:要路由的文本字符串。

引发:

  • TypeError:如果输入不是字符串。
  • RuntimeError:如果由于之前未调用 warm_up() 而未加载管道。

返回值:

一个字典,其中键是标签,值是文本。