TransformersZeroShotDocumentClassifier
根据提供的标签对文档进行分类,并将其添加到元数据中。
| pipeline 中的最常见位置 | 在 MetadataRouter 之前 |
| 必需的初始化变量 | “model”: 用于零样本文档分类的 Hugging Face 模型名称或路径 ”labels”: 要将每个文档分类到的可能的类标签集合,例如 ["positive", "negative"]。标签取决于所选模型。 |
| 强制运行变量 | “documents”: 要分类的文档列表 |
| 输出变量 | “documents”: 一系列已处理的文档,其中添加了一个“classification”元数据字段 |
| API 参考 | Classifiers (分类器) |
| GitHub 链接 | https://github.com/deepset-ai/haystack/blob/main/haystack/components/classifiers/zero_shot_document_classifier.py |
概述
该TransformersZeroShotDocumentClassifier 组件根据您设置的标签执行文档的零样本分类,并将预测的标签添加到其元数据中。
该组件使用 Hugging Face 管道进行零样本分类。
要初始化该组件,请提供要用于分类的模型和标签集。
您还可以通过将multi_label 布尔值设置为 True 来配置该组件以允许多个标签为真。
默认情况下,分类在文档的内容字段上运行。如果您希望它在其他字段上运行,请将classification_field 设置为文档的元数据字段之一。
分类结果存储在每个文档元数据内的classification 字典中。如果multi_label 设置为True,您将在classification 字典中的details 键下找到每个标签的得分。
零样本分类任务的可用模型是
- valhalla/distilbart-mnli-12-3
- cross-encoder/nli-distilroberta-base
- cross-encoder/nli-deberta-v3-xsmall
用法
单独使用
from haystack import Document
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier
documents = [Document(id="0", content="Cats don't get teeth cavities."),
Document(id="1", content="Cucumbers can be grown in water.")]
document_classifier = TransformersZeroShotDocumentClassifier(
model="cross-encoder/nli-deberta-v3-xsmall",
labels=["animals", "food"],
)
document_classifier.warm_up()
document_classifier.run(documents = documents)
在 pipeline 中
以下是一个管道,该管道根据预定义的分类标签对文档进行分类
从搜索管道中检索
from haystack import Document
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.core.pipeline import Pipeline
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier
documents = [Document(id="0", content="Today was a nice day!"),
Document(id="1", content="Yesterday was a bad day!")]
document_store = InMemoryDocumentStore()
retriever = InMemoryBM25Retriever(document_store=document_store)
document_classifier = TransformersZeroShotDocumentClassifier(
model="cross-encoder/nli-deberta-v3-xsmall",
labels=["positive", "negative"],
)
document_store.write_documents(documents)
pipeline = Pipeline()
pipeline.add_component(instance=retriever, name="retriever")
pipeline.add_component(instance=document_classifier, name="document_classifier")
pipeline.connect("retriever", "document_classifier")
queries = ["How was your day today?", "How was your day yesterday?"]
expected_predictions = ["positive", "negative"]
for idx, query in enumerate(queries):
result = pipeline.run({"retriever": {"query": query, "top_k": 1}})
assert result["document_classifier"]["documents"][0].to_dict()["id"] == str(idx)
assert (result["document_classifier"]["documents"][0].to_dict()["classification"]["label"]
== expected_predictions[idx])
更新于 大约 1 年前
