文档API 参考📓 教程🧑‍🍳 食谱🤝 集成💜 Discord🎨 Studio
文档

TransformersZeroShotDocumentClassifier

根据提供的标签对文档进行分类,并将其添加到元数据中。

pipeline 中的最常见位置MetadataRouter 之前
必需的初始化变量“model”: 用于零样本文档分类的 Hugging Face 模型名称或路径

”labels”: 要将每个文档分类到的可能的类标签集合,例如 ["positive", "negative"]。标签取决于所选模型。
强制运行变量“documents”: 要分类的文档列表
输出变量“documents”: 一系列已处理的文档,其中添加了一个“classification”元数据字段
API 参考Classifiers (分类器)
GitHub 链接https://github.com/deepset-ai/haystack/blob/main/haystack/components/classifiers/zero_shot_document_classifier.py

概述

TransformersZeroShotDocumentClassifier 组件根据您设置的标签执行文档的零样本分类,并将预测的标签添加到其元数据中。

该组件使用 Hugging Face 管道进行零样本分类。
要初始化该组件,请提供要用于分类的模型和标签集。
您还可以通过将multi_label 布尔值设置为 True 来配置该组件以允许多个标签为真。

默认情况下,分类在文档的内容字段上运行。如果您希望它在其他字段上运行,请将classification_field 设置为文档的元数据字段之一。

分类结果存储在每个文档元数据内的classification 字典中。如果multi_label 设置为True,您将在classification 字典中的details 键下找到每个标签的得分。

零样本分类任务的可用模型是
- valhalla/distilbart-mnli-12-3
- cross-encoder/nli-distilroberta-base
- cross-encoder/nli-deberta-v3-xsmall

用法

单独使用

from haystack import Document
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

documents = [Document(id="0", content="Cats don't get teeth cavities."),
             Document(id="1", content="Cucumbers can be grown in water.")]
             
document_classifier = TransformersZeroShotDocumentClassifier(
    model="cross-encoder/nli-deberta-v3-xsmall",
    labels=["animals", "food"],
)

document_classifier.warm_up()
document_classifier.run(documents = documents)    

在 pipeline 中

以下是一个管道,该管道根据预定义的分类标签对文档进行分类
从搜索管道中检索

from haystack import Document
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.core.pipeline import Pipeline
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

documents = [Document(id="0", content="Today was a nice day!"),
             Document(id="1", content="Yesterday was a bad day!")]

document_store = InMemoryDocumentStore()
retriever = InMemoryBM25Retriever(document_store=document_store)
document_classifier = TransformersZeroShotDocumentClassifier(
    model="cross-encoder/nli-deberta-v3-xsmall",
    labels=["positive", "negative"],
)

document_store.write_documents(documents)

pipeline = Pipeline()
pipeline.add_component(instance=retriever, name="retriever")
pipeline.add_component(instance=document_classifier, name="document_classifier")
pipeline.connect("retriever", "document_classifier")

queries = ["How was your day today?", "How was your day yesterday?"]
expected_predictions = ["positive", "negative"]

for idx, query in enumerate(queries):
    result = pipeline.run({"retriever": {"query": query, "top_k": 1}})
    assert result["document_classifier"]["documents"][0].to_dict()["id"] == str(idx)
    assert (result["document_classifier"]["documents"][0].to_dict()["classification"]["label"]
            == expected_predictions[idx])