mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-25 05:58:57 +00:00
TransformersDocumentClassifier replacing FARMClassifier (#1540)
* Initial draft of TransformersClassifier * Add transformers classifier implementation * Add test for SentenceTransformersClassifier * Add truncation and corresponding test case to Classifier * Add zero-shot classification and test * Add document classifier documentation * Add latest docstring and tutorial changes * print meta data with print_documents() * Add latest docstring and tutorial changes * Remove top_k param from Classifier usage example * Add latest docstring and tutorial changes Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
a20eec3098
commit
24483d7bad
108
docs/_src/api/api/document_classifier.md
Normal file
108
docs/_src/api/api/document_classifier.md
Normal file
@ -0,0 +1,108 @@
|
||||
<a name="base"></a>
|
||||
# Module base
|
||||
|
||||
<a name="base.BaseDocumentClassifier"></a>
|
||||
## BaseDocumentClassifier Objects
|
||||
|
||||
```python
|
||||
class BaseDocumentClassifier(BaseComponent)
|
||||
```
|
||||
|
||||
<a name="base.BaseDocumentClassifier.timing"></a>
|
||||
#### timing
|
||||
|
||||
```python
|
||||
| timing(fn, attr_name)
|
||||
```
|
||||
|
||||
Wrapper method used to time functions.
|
||||
|
||||
<a name="transformers"></a>
|
||||
# Module transformers
|
||||
|
||||
<a name="transformers.TransformersDocumentClassifier"></a>
|
||||
## TransformersDocumentClassifier Objects
|
||||
|
||||
```python
|
||||
class TransformersDocumentClassifier(BaseDocumentClassifier)
|
||||
```
|
||||
|
||||
Transformer based model for document classification using the HuggingFace's transformers framework
|
||||
(https://github.com/huggingface/transformers).
|
||||
While the underlying model can vary (BERT, Roberta, DistilBERT ...), the interface remains the same.
|
||||
This node classifies documents and adds the output from the classification step to the document's meta data.
|
||||
The meta field of the document is a dictionary with the following format:
|
||||
'meta': {'name': '450_Baelor.txt', 'classification': {'label': 'neutral', 'probability' = 0.9997646, ...} }
|
||||
|
||||
With this document_classifier, you can directly get predictions via predict()
|
||||
|
||||
Usage example:
|
||||
...
|
||||
retriever = ElasticsearchRetriever(document_store=document_store)
|
||||
document_classifier = TransformersDocumentClassifier(model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion")
|
||||
p = Pipeline()
|
||||
p.add_node(component=retriever, name="Retriever", inputs=["Query"])
|
||||
p.add_node(component=document_classifier, name="Classifier", inputs=["Retriever"])
|
||||
res = p.run(
|
||||
query="Who is the father of Arya Stark?",
|
||||
params={"Retriever": {"top_k": 10}}
|
||||
)
|
||||
|
||||
__print the classification results__
|
||||
|
||||
print_documents(res, max_text_len=100, print_meta=True)
|
||||
__or access the predicted class label directly__
|
||||
|
||||
res["documents"][0].to_dict()["meta"]["classification"]["label"]
|
||||
|
||||
<a name="transformers.TransformersDocumentClassifier.__init__"></a>
|
||||
#### \_\_init\_\_
|
||||
|
||||
```python
|
||||
| __init__(model_name_or_path: str = "bhadresh-savani/distilbert-base-uncased-emotion", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: int = 0, return_all_scores: bool = False, task: str = 'text-classification', labels: Optional[List[str]] = None)
|
||||
```
|
||||
|
||||
Load a text classification model from Transformers.
|
||||
Available models for the task of text-classification include:
|
||||
- ``'bhadresh-savani/distilbert-base-uncased-emotion'``
|
||||
- ``'Hate-speech-CNERG/dehatebert-mono-english'``
|
||||
|
||||
Available models for the task of zero-shot-classification include:
|
||||
- ``'valhalla/distilbart-mnli-12-3'``
|
||||
- ``'cross-encoder/nli-distilroberta-base'``
|
||||
|
||||
See https://huggingface.co/models for full list of available models.
|
||||
Filter for text classification models: https://huggingface.co/models?pipeline_tag=text-classification&sort=downloads
|
||||
Filter for zero-shot classification models (NLI): https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli
|
||||
|
||||
**Arguments**:
|
||||
|
||||
- `model_name_or_path`: Directory of a saved model or the name of a public model e.g. 'bhadresh-savani/distilbert-base-uncased-emotion'.
|
||||
See https://huggingface.co/models for full list of available models.
|
||||
- `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
|
||||
- `tokenizer`: Name of the tokenizer (usually the same as model)
|
||||
- `use_gpu`: If < 0, then use cpu. If >= 0, this is the ordinal of the gpu to use
|
||||
- `return_all_scores`: Whether to return all prediction scores or just the one of the predicted class. Only used for task 'text-classification'.
|
||||
- `task`: 'text-classification' or 'zero-shot-classification'
|
||||
- `labels`: Only used for task 'zero-shot-classification'. List of string defining class labels, e.g.,
|
||||
["positive", "negative"] otherwise None. Given a LABEL, the sequence fed to the model is "<cls> sequence to
|
||||
classify <sep> This example is LABEL . <sep>" and the model predicts whether that sequence is a contradiction
|
||||
or an entailment.
|
||||
|
||||
<a name="transformers.TransformersDocumentClassifier.predict"></a>
|
||||
#### predict
|
||||
|
||||
```python
|
||||
| predict(documents: List[Document]) -> List[Document]
|
||||
```
|
||||
|
||||
Returns documents containing classification result in meta field
|
||||
|
||||
**Arguments**:
|
||||
|
||||
- `documents`: List of Document to classify
|
||||
|
||||
**Returns**:
|
||||
|
||||
List of Document enriched with meta information
|
||||
|
||||
@ -17,4 +17,5 @@ pydoc-markdown pydoc-markdown-graph-retriever.yml
|
||||
pydoc-markdown pydoc-markdown-evaluation.yml
|
||||
pydoc-markdown pydoc-markdown-ranker.yml
|
||||
pydoc-markdown pydoc-markdown-question-generator.yml
|
||||
pydoc-markdown pydoc-markdown-document-classifier.yml
|
||||
|
||||
|
||||
18
docs/_src/api/api/pydoc-markdown-document-classifier.yml
Normal file
18
docs/_src/api/api/pydoc-markdown-document-classifier.yml
Normal file
@ -0,0 +1,18 @@
|
||||
loaders:
|
||||
- type: python
|
||||
search_path: [../../../../haystack/document_classifier]
|
||||
modules: ['base', 'transformers']
|
||||
ignore_when_discovered: ['__init__']
|
||||
processor:
|
||||
- type: filter
|
||||
expression: not name.startswith('_') and default()
|
||||
- documented_only: true
|
||||
- do_not_filter_modules: false
|
||||
- skip_empty_modules: true
|
||||
renderer:
|
||||
type: markdown
|
||||
descriptive_class_title: true
|
||||
descriptive_module_title: true
|
||||
add_method_class_prefix: false
|
||||
add_member_class_prefix: false
|
||||
filename: document_classifier.md
|
||||
1
haystack/document_classifier/__init__.py
Normal file
1
haystack/document_classifier/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from haystack.document_classifier.transformers import TransformersDocumentClassifier
|
||||
57
haystack/document_classifier/base.py
Normal file
57
haystack/document_classifier/base.py
Normal file
@ -0,0 +1,57 @@
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import List
|
||||
from functools import wraps
|
||||
from time import perf_counter
|
||||
|
||||
|
||||
from haystack import Document, BaseComponent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseDocumentClassifier(BaseComponent):
|
||||
outgoing_edges = 1
|
||||
query_count = 0
|
||||
query_time = 0
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, documents: List[Document]):
|
||||
pass
|
||||
|
||||
def run(self, query: str, documents: List[Document]): # type: ignore
|
||||
self.query_count += 1
|
||||
if documents:
|
||||
predict = self.timing(self.predict, "query_time")
|
||||
results = predict(documents=documents)
|
||||
else:
|
||||
results = []
|
||||
|
||||
document_ids = [doc.id for doc in results]
|
||||
logger.debug(f"Retrieved documents with IDs: {document_ids}")
|
||||
output = {"documents": results}
|
||||
|
||||
return output, "output_1"
|
||||
|
||||
def timing(self, fn, attr_name):
|
||||
"""Wrapper method used to time functions. """
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if attr_name not in self.__dict__:
|
||||
self.__dict__[attr_name] = 0
|
||||
tic = perf_counter()
|
||||
ret = fn(*args, **kwargs)
|
||||
toc = perf_counter()
|
||||
self.__dict__[attr_name] += toc - tic
|
||||
return ret
|
||||
return wrapper
|
||||
|
||||
def print_time(self):
|
||||
print("Classifier (Speed)")
|
||||
print("---------------")
|
||||
if not self.query_count:
|
||||
print("No querying performed via Classifier.run()")
|
||||
else:
|
||||
print(f"Queries Performed: {self.query_count}")
|
||||
print(f"Query time: {self.query_time}s")
|
||||
print(f"{self.query_time / self.query_count} seconds per query")
|
||||
121
haystack/document_classifier/transformers.py
Normal file
121
haystack/document_classifier/transformers.py
Normal file
@ -0,0 +1,121 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import pipeline
|
||||
|
||||
from haystack import Document
|
||||
from haystack.document_classifier.base import BaseDocumentClassifier
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransformersDocumentClassifier(BaseDocumentClassifier):
|
||||
"""
|
||||
Transformer based model for document classification using the HuggingFace's transformers framework
|
||||
(https://github.com/huggingface/transformers).
|
||||
While the underlying model can vary (BERT, Roberta, DistilBERT ...), the interface remains the same.
|
||||
This node classifies documents and adds the output from the classification step to the document's meta data.
|
||||
The meta field of the document is a dictionary with the following format:
|
||||
'meta': {'name': '450_Baelor.txt', 'classification': {'label': 'neutral', 'probability' = 0.9997646, ...} }
|
||||
|
||||
With this document_classifier, you can directly get predictions via predict()
|
||||
|
||||
Usage example:
|
||||
...
|
||||
retriever = ElasticsearchRetriever(document_store=document_store)
|
||||
document_classifier = TransformersDocumentClassifier(model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion")
|
||||
p = Pipeline()
|
||||
p.add_node(component=retriever, name="Retriever", inputs=["Query"])
|
||||
p.add_node(component=document_classifier, name="Classifier", inputs=["Retriever"])
|
||||
res = p.run(
|
||||
query="Who is the father of Arya Stark?",
|
||||
params={"Retriever": {"top_k": 10}}
|
||||
)
|
||||
|
||||
# print the classification results
|
||||
print_documents(res, max_text_len=100, print_meta=True)
|
||||
# or access the predicted class label directly
|
||||
res["documents"][0].to_dict()["meta"]["classification"]["label"]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path: str = "bhadresh-savani/distilbert-base-uncased-emotion",
|
||||
model_version: Optional[str] = None,
|
||||
tokenizer: Optional[str] = None,
|
||||
use_gpu: int = 0,
|
||||
return_all_scores: bool = False,
|
||||
task: str = 'text-classification',
|
||||
labels: Optional[List[str]] = None
|
||||
):
|
||||
"""
|
||||
Load a text classification model from Transformers.
|
||||
Available models for the task of text-classification include:
|
||||
- ``'bhadresh-savani/distilbert-base-uncased-emotion'``
|
||||
- ``'Hate-speech-CNERG/dehatebert-mono-english'``
|
||||
|
||||
Available models for the task of zero-shot-classification include:
|
||||
- ``'valhalla/distilbart-mnli-12-3'``
|
||||
- ``'cross-encoder/nli-distilroberta-base'``
|
||||
|
||||
See https://huggingface.co/models for full list of available models.
|
||||
Filter for text classification models: https://huggingface.co/models?pipeline_tag=text-classification&sort=downloads
|
||||
Filter for zero-shot classification models (NLI): https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli
|
||||
|
||||
:param model_name_or_path: Directory of a saved model or the name of a public model e.g. 'bhadresh-savani/distilbert-base-uncased-emotion'.
|
||||
See https://huggingface.co/models for full list of available models.
|
||||
:param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
|
||||
:param tokenizer: Name of the tokenizer (usually the same as model)
|
||||
:param use_gpu: If < 0, then use cpu. If >= 0, this is the ordinal of the gpu to use
|
||||
:param return_all_scores: Whether to return all prediction scores or just the one of the predicted class. Only used for task 'text-classification'.
|
||||
:param task: 'text-classification' or 'zero-shot-classification'
|
||||
:param labels: Only used for task 'zero-shot-classification'. List of string defining class labels, e.g.,
|
||||
["positive", "negative"] otherwise None. Given a LABEL, the sequence fed to the model is "<cls> sequence to
|
||||
classify <sep> This example is LABEL . <sep>" and the model predicts whether that sequence is a contradiction
|
||||
or an entailment.
|
||||
|
||||
"""
|
||||
|
||||
# save init parameters to enable export of component config as YAML
|
||||
self.set_config(
|
||||
model_name_or_path=model_name_or_path, model_version=model_version, tokenizer=tokenizer,
|
||||
use_gpu=use_gpu, return_all_scores=return_all_scores, labels=labels, task=task
|
||||
)
|
||||
if labels and task == 'text-classification':
|
||||
logger.warning(f'Provided labels {labels} will be ignored for task text-classification. Set task to '
|
||||
f'zero-shot-classification to use labels.')
|
||||
|
||||
if tokenizer is None:
|
||||
tokenizer = model_name_or_path
|
||||
if task == 'zero-shot-classification':
|
||||
self.model = pipeline(task=task, model=model_name_or_path, tokenizer=tokenizer, device=use_gpu, revision=model_version)
|
||||
elif task == 'text-classification':
|
||||
self.model = pipeline(task=task, model=model_name_or_path, tokenizer=tokenizer, device=use_gpu, revision=model_version, return_all_scores=return_all_scores)
|
||||
self.return_all_scores = return_all_scores
|
||||
self.labels = labels
|
||||
self.task = task
|
||||
|
||||
def predict(self, documents: List[Document]) -> List[Document]:
|
||||
"""
|
||||
Returns documents containing classification result in meta field
|
||||
|
||||
:param documents: List of Document to classify
|
||||
:return: List of Document enriched with meta information
|
||||
|
||||
"""
|
||||
texts = [doc.text for doc in documents]
|
||||
if self.task == 'zero-shot-classification':
|
||||
predictions = self.model(texts, candidate_labels=self.labels, truncation=True)
|
||||
elif self.task == 'text-classification':
|
||||
predictions = self.model(texts, return_all_scores=self.return_all_scores, truncation=True)
|
||||
|
||||
classified_docs: List[Document] = []
|
||||
|
||||
for prediction, doc in zip(predictions, documents):
|
||||
cur_doc = doc
|
||||
cur_doc.meta["classification"] = prediction
|
||||
if self.task == 'zero-shot-classification':
|
||||
cur_doc.meta["classification"]["label"] = cur_doc.meta["classification"]["labels"][0]
|
||||
classified_docs.append(cur_doc)
|
||||
|
||||
return classified_docs
|
||||
@ -4,7 +4,7 @@ from itertools import islice
|
||||
import logging
|
||||
import pprint
|
||||
import pandas as pd
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, Optional
|
||||
from haystack.document_store.sql import DocumentORM
|
||||
import subprocess
|
||||
import time
|
||||
@ -134,8 +134,7 @@ def print_answers(results: dict, details: str = "all"):
|
||||
pp.pprint(results)
|
||||
|
||||
|
||||
|
||||
def print_documents(results: dict, max_text_len: int=None):
|
||||
def print_documents(results: dict, max_text_len: Optional[int] = None, print_meta: bool = False):
|
||||
print(f"Query: {results['query']}")
|
||||
pp = pprint.PrettyPrinter(indent=4)
|
||||
for d in results["documents"]:
|
||||
@ -147,6 +146,8 @@ def print_documents(results: dict, max_text_len: int=None):
|
||||
"name": d["meta"]["name"],
|
||||
"text": new_text
|
||||
}
|
||||
if print_meta:
|
||||
results["meta"] = d["meta"]
|
||||
pp.pprint(results)
|
||||
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ from haystack.document_store.milvus import MilvusDocumentStore
|
||||
from haystack.generator.transformers import RAGenerator, RAGeneratorType
|
||||
from haystack.modeling.infer import Inferencer, QAInferencer
|
||||
from haystack.ranker import SentenceTransformersRanker
|
||||
from haystack.document_classifier.transformers import TransformersDocumentClassifier
|
||||
|
||||
from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever, ElasticsearchRetriever, TfidfRetriever
|
||||
|
||||
@ -337,6 +338,23 @@ def ranker():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def document_classifier():
|
||||
return TransformersDocumentClassifier(
|
||||
model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion",
|
||||
use_gpu=-1
|
||||
)
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def zero_shot_document_classifier():
|
||||
return TransformersDocumentClassifier(
|
||||
model_name_or_path="cross-encoder/nli-distilroberta-base",
|
||||
use_gpu=-1,
|
||||
task="zero-shot-classification",
|
||||
labels=["negative", "positive"]
|
||||
)
|
||||
|
||||
|
||||
# TODO Fix bug in test_no_answer_output when using
|
||||
# @pytest.fixture(params=["farm", "transformers"])
|
||||
@pytest.fixture(params=["farm"], scope="module")
|
||||
|
||||
48
test/test_document_classifier.py
Normal file
48
test/test_document_classifier.py
Normal file
@ -0,0 +1,48 @@
|
||||
import pytest
|
||||
|
||||
from haystack import Document
|
||||
from haystack.document_classifier.base import BaseDocumentClassifier
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_document_classifier(document_classifier):
|
||||
assert isinstance(document_classifier, BaseDocumentClassifier)
|
||||
|
||||
docs = [
|
||||
Document(
|
||||
text="""That's good. I like it."""*700, # extra long text to check truncation
|
||||
meta={"name": "0"},
|
||||
id="1",
|
||||
),
|
||||
Document(
|
||||
text="""That's bad. I don't like it.""",
|
||||
meta={"name": "1"},
|
||||
id="2",
|
||||
),
|
||||
]
|
||||
results = document_classifier.predict(documents=docs)
|
||||
expected_labels = ["joy", "sadness"]
|
||||
for i, doc in enumerate(results):
|
||||
assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_zero_shot_document_classifier(zero_shot_document_classifier):
|
||||
assert isinstance(zero_shot_document_classifier, BaseDocumentClassifier)
|
||||
|
||||
docs = [
|
||||
Document(
|
||||
text="""That's good. I like it."""*700, # extra long text to check truncation
|
||||
meta={"name": "0"},
|
||||
id="1",
|
||||
),
|
||||
Document(
|
||||
text="""That's bad. I don't like it.""",
|
||||
meta={"name": "1"},
|
||||
id="2",
|
||||
),
|
||||
]
|
||||
results = zero_shot_document_classifier.predict(documents=docs)
|
||||
expected_labels = ["positive", "negative"]
|
||||
for i, doc in enumerate(results):
|
||||
assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]
|
||||
Loading…
x
Reference in New Issue
Block a user