TransformersDocumentClassifier replacing FARMClassifier (#1540)

* Initial draft of TransformersClassifier

* Add transformers classifier implementation

* Add test for SentenceTransformersClassifier

* Add truncation and corresponding test case to Classifier

* Add zero-shot classification and test

* Add document classifier documentation

* Add latest docstring and tutorial changes

* print meta data with print_documents()

* Add latest docstring and tutorial changes

* Remove top_k param from Classifier usage example

* Add latest docstring and tutorial changes

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Julian Risch 2021-10-01 11:22:56 +02:00 committed by GitHub
parent a20eec3098
commit 24483d7bad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 376 additions and 3 deletions

View File

@ -0,0 +1,108 @@
<a name="base"></a>
# Module base
<a name="base.BaseDocumentClassifier"></a>
## BaseDocumentClassifier Objects
```python
class BaseDocumentClassifier(BaseComponent)
```
<a name="base.BaseDocumentClassifier.timing"></a>
#### timing
```python
| timing(fn, attr_name)
```
Wrapper method used to time functions.
<a name="transformers"></a>
# Module transformers
<a name="transformers.TransformersDocumentClassifier"></a>
## TransformersDocumentClassifier Objects
```python
class TransformersDocumentClassifier(BaseDocumentClassifier)
```
Transformer based model for document classification using the HuggingFace's transformers framework
(https://github.com/huggingface/transformers).
While the underlying model can vary (BERT, Roberta, DistilBERT ...), the interface remains the same.
This node classifies documents and adds the output from the classification step to the document's meta data.
The meta field of the document is a dictionary with the following format:
'meta': {'name': '450_Baelor.txt', 'classification': {'label': 'neutral', 'probability' = 0.9997646, ...} }
With this document_classifier, you can directly get predictions via predict()
Usage example:
...
retriever = ElasticsearchRetriever(document_store=document_store)
document_classifier = TransformersDocumentClassifier(model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion")
p = Pipeline()
p.add_node(component=retriever, name="Retriever", inputs=["Query"])
p.add_node(component=document_classifier, name="Classifier", inputs=["Retriever"])
res = p.run(
query="Who is the father of Arya Stark?",
params={"Retriever": {"top_k": 10}}
)
__print the classification results__
print_documents(res, max_text_len=100, print_meta=True)
__or access the predicted class label directly__
res["documents"][0].to_dict()["meta"]["classification"]["label"]
<a name="transformers.TransformersDocumentClassifier.__init__"></a>
#### \_\_init\_\_
```python
| __init__(model_name_or_path: str = "bhadresh-savani/distilbert-base-uncased-emotion", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: int = 0, return_all_scores: bool = False, task: str = 'text-classification', labels: Optional[List[str]] = None)
```
Load a text classification model from Transformers.
Available models for the task of text-classification include:
- ``'bhadresh-savani/distilbert-base-uncased-emotion'``
- ``'Hate-speech-CNERG/dehatebert-mono-english'``
Available models for the task of zero-shot-classification include:
- ``'valhalla/distilbart-mnli-12-3'``
- ``'cross-encoder/nli-distilroberta-base'``
See https://huggingface.co/models for full list of available models.
Filter for text classification models: https://huggingface.co/models?pipeline_tag=text-classification&sort=downloads
Filter for zero-shot classification models (NLI): https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli
**Arguments**:
- `model_name_or_path`: Directory of a saved model or the name of a public model e.g. 'bhadresh-savani/distilbert-base-uncased-emotion'.
See https://huggingface.co/models for full list of available models.
- `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
- `tokenizer`: Name of the tokenizer (usually the same as model)
- `use_gpu`: If < 0, then use cpu. If >= 0, this is the ordinal of the gpu to use
- `return_all_scores`: Whether to return all prediction scores or just the one of the predicted class. Only used for task 'text-classification'.
- `task`: 'text-classification' or 'zero-shot-classification'
- `labels`: Only used for task 'zero-shot-classification'. List of string defining class labels, e.g.,
["positive", "negative"] otherwise None. Given a LABEL, the sequence fed to the model is "<cls> sequence to
classify <sep> This example is LABEL . <sep>" and the model predicts whether that sequence is a contradiction
or an entailment.
<a name="transformers.TransformersDocumentClassifier.predict"></a>
#### predict
```python
| predict(documents: List[Document]) -> List[Document]
```
Returns documents containing classification result in meta field
**Arguments**:
- `documents`: List of Document to classify
**Returns**:
List of Document enriched with meta information

View File

@ -17,4 +17,5 @@ pydoc-markdown pydoc-markdown-graph-retriever.yml
pydoc-markdown pydoc-markdown-evaluation.yml
pydoc-markdown pydoc-markdown-ranker.yml
pydoc-markdown pydoc-markdown-question-generator.yml
pydoc-markdown pydoc-markdown-document-classifier.yml

View File

@ -0,0 +1,18 @@
loaders:
- type: python
search_path: [../../../../haystack/document_classifier]
modules: ['base', 'transformers']
ignore_when_discovered: ['__init__']
processor:
- type: filter
expression: not name.startswith('_') and default()
- documented_only: true
- do_not_filter_modules: false
- skip_empty_modules: true
renderer:
type: markdown
descriptive_class_title: true
descriptive_module_title: true
add_method_class_prefix: false
add_member_class_prefix: false
filename: document_classifier.md

View File

@ -0,0 +1 @@
from haystack.document_classifier.transformers import TransformersDocumentClassifier

View File

@ -0,0 +1,57 @@
import logging
from abc import abstractmethod
from typing import List
from functools import wraps
from time import perf_counter
from haystack import Document, BaseComponent
logger = logging.getLogger(__name__)
class BaseDocumentClassifier(BaseComponent):
outgoing_edges = 1
query_count = 0
query_time = 0
@abstractmethod
def predict(self, documents: List[Document]):
pass
def run(self, query: str, documents: List[Document]): # type: ignore
self.query_count += 1
if documents:
predict = self.timing(self.predict, "query_time")
results = predict(documents=documents)
else:
results = []
document_ids = [doc.id for doc in results]
logger.debug(f"Retrieved documents with IDs: {document_ids}")
output = {"documents": results}
return output, "output_1"
def timing(self, fn, attr_name):
"""Wrapper method used to time functions. """
@wraps(fn)
def wrapper(*args, **kwargs):
if attr_name not in self.__dict__:
self.__dict__[attr_name] = 0
tic = perf_counter()
ret = fn(*args, **kwargs)
toc = perf_counter()
self.__dict__[attr_name] += toc - tic
return ret
return wrapper
def print_time(self):
print("Classifier (Speed)")
print("---------------")
if not self.query_count:
print("No querying performed via Classifier.run()")
else:
print(f"Queries Performed: {self.query_count}")
print(f"Query time: {self.query_time}s")
print(f"{self.query_time / self.query_count} seconds per query")

View File

@ -0,0 +1,121 @@
import logging
from typing import List, Optional
from transformers import pipeline
from haystack import Document
from haystack.document_classifier.base import BaseDocumentClassifier
logger = logging.getLogger(__name__)
class TransformersDocumentClassifier(BaseDocumentClassifier):
"""
Transformer based model for document classification using the HuggingFace's transformers framework
(https://github.com/huggingface/transformers).
While the underlying model can vary (BERT, Roberta, DistilBERT ...), the interface remains the same.
This node classifies documents and adds the output from the classification step to the document's meta data.
The meta field of the document is a dictionary with the following format:
'meta': {'name': '450_Baelor.txt', 'classification': {'label': 'neutral', 'probability' = 0.9997646, ...} }
With this document_classifier, you can directly get predictions via predict()
Usage example:
...
retriever = ElasticsearchRetriever(document_store=document_store)
document_classifier = TransformersDocumentClassifier(model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion")
p = Pipeline()
p.add_node(component=retriever, name="Retriever", inputs=["Query"])
p.add_node(component=document_classifier, name="Classifier", inputs=["Retriever"])
res = p.run(
query="Who is the father of Arya Stark?",
params={"Retriever": {"top_k": 10}}
)
# print the classification results
print_documents(res, max_text_len=100, print_meta=True)
# or access the predicted class label directly
res["documents"][0].to_dict()["meta"]["classification"]["label"]
"""
def __init__(
self,
model_name_or_path: str = "bhadresh-savani/distilbert-base-uncased-emotion",
model_version: Optional[str] = None,
tokenizer: Optional[str] = None,
use_gpu: int = 0,
return_all_scores: bool = False,
task: str = 'text-classification',
labels: Optional[List[str]] = None
):
"""
Load a text classification model from Transformers.
Available models for the task of text-classification include:
- ``'bhadresh-savani/distilbert-base-uncased-emotion'``
- ``'Hate-speech-CNERG/dehatebert-mono-english'``
Available models for the task of zero-shot-classification include:
- ``'valhalla/distilbart-mnli-12-3'``
- ``'cross-encoder/nli-distilroberta-base'``
See https://huggingface.co/models for full list of available models.
Filter for text classification models: https://huggingface.co/models?pipeline_tag=text-classification&sort=downloads
Filter for zero-shot classification models (NLI): https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli
:param model_name_or_path: Directory of a saved model or the name of a public model e.g. 'bhadresh-savani/distilbert-base-uncased-emotion'.
See https://huggingface.co/models for full list of available models.
:param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
:param tokenizer: Name of the tokenizer (usually the same as model)
:param use_gpu: If < 0, then use cpu. If >= 0, this is the ordinal of the gpu to use
:param return_all_scores: Whether to return all prediction scores or just the one of the predicted class. Only used for task 'text-classification'.
:param task: 'text-classification' or 'zero-shot-classification'
:param labels: Only used for task 'zero-shot-classification'. List of string defining class labels, e.g.,
["positive", "negative"] otherwise None. Given a LABEL, the sequence fed to the model is "<cls> sequence to
classify <sep> This example is LABEL . <sep>" and the model predicts whether that sequence is a contradiction
or an entailment.
"""
# save init parameters to enable export of component config as YAML
self.set_config(
model_name_or_path=model_name_or_path, model_version=model_version, tokenizer=tokenizer,
use_gpu=use_gpu, return_all_scores=return_all_scores, labels=labels, task=task
)
if labels and task == 'text-classification':
logger.warning(f'Provided labels {labels} will be ignored for task text-classification. Set task to '
f'zero-shot-classification to use labels.')
if tokenizer is None:
tokenizer = model_name_or_path
if task == 'zero-shot-classification':
self.model = pipeline(task=task, model=model_name_or_path, tokenizer=tokenizer, device=use_gpu, revision=model_version)
elif task == 'text-classification':
self.model = pipeline(task=task, model=model_name_or_path, tokenizer=tokenizer, device=use_gpu, revision=model_version, return_all_scores=return_all_scores)
self.return_all_scores = return_all_scores
self.labels = labels
self.task = task
def predict(self, documents: List[Document]) -> List[Document]:
"""
Returns documents containing classification result in meta field
:param documents: List of Document to classify
:return: List of Document enriched with meta information
"""
texts = [doc.text for doc in documents]
if self.task == 'zero-shot-classification':
predictions = self.model(texts, candidate_labels=self.labels, truncation=True)
elif self.task == 'text-classification':
predictions = self.model(texts, return_all_scores=self.return_all_scores, truncation=True)
classified_docs: List[Document] = []
for prediction, doc in zip(predictions, documents):
cur_doc = doc
cur_doc.meta["classification"] = prediction
if self.task == 'zero-shot-classification':
cur_doc.meta["classification"]["label"] = cur_doc.meta["classification"]["labels"][0]
classified_docs.append(cur_doc)
return classified_docs

View File

@ -4,7 +4,7 @@ from itertools import islice
import logging
import pprint
import pandas as pd
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional
from haystack.document_store.sql import DocumentORM
import subprocess
import time
@ -134,8 +134,7 @@ def print_answers(results: dict, details: str = "all"):
pp.pprint(results)
def print_documents(results: dict, max_text_len: int=None):
def print_documents(results: dict, max_text_len: Optional[int] = None, print_meta: bool = False):
print(f"Query: {results['query']}")
pp = pprint.PrettyPrinter(indent=4)
for d in results["documents"]:
@ -147,6 +146,8 @@ def print_documents(results: dict, max_text_len: int=None):
"name": d["meta"]["name"],
"text": new_text
}
if print_meta:
results["meta"] = d["meta"]
pp.pprint(results)

View File

@ -19,6 +19,7 @@ from haystack.document_store.milvus import MilvusDocumentStore
from haystack.generator.transformers import RAGenerator, RAGeneratorType
from haystack.modeling.infer import Inferencer, QAInferencer
from haystack.ranker import SentenceTransformersRanker
from haystack.document_classifier.transformers import TransformersDocumentClassifier
from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever, ElasticsearchRetriever, TfidfRetriever
@ -337,6 +338,23 @@ def ranker():
)
@pytest.fixture(scope="module")
def document_classifier():
return TransformersDocumentClassifier(
model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion",
use_gpu=-1
)
@pytest.fixture(scope="module")
def zero_shot_document_classifier():
return TransformersDocumentClassifier(
model_name_or_path="cross-encoder/nli-distilroberta-base",
use_gpu=-1,
task="zero-shot-classification",
labels=["negative", "positive"]
)
# TODO Fix bug in test_no_answer_output when using
# @pytest.fixture(params=["farm", "transformers"])
@pytest.fixture(params=["farm"], scope="module")

View File

@ -0,0 +1,48 @@
import pytest
from haystack import Document
from haystack.document_classifier.base import BaseDocumentClassifier
@pytest.mark.slow
def test_document_classifier(document_classifier):
assert isinstance(document_classifier, BaseDocumentClassifier)
docs = [
Document(
text="""That's good. I like it."""*700, # extra long text to check truncation
meta={"name": "0"},
id="1",
),
Document(
text="""That's bad. I don't like it.""",
meta={"name": "1"},
id="2",
),
]
results = document_classifier.predict(documents=docs)
expected_labels = ["joy", "sadness"]
for i, doc in enumerate(results):
assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]
@pytest.mark.slow
def test_zero_shot_document_classifier(zero_shot_document_classifier):
assert isinstance(zero_shot_document_classifier, BaseDocumentClassifier)
docs = [
Document(
text="""That's good. I like it."""*700, # extra long text to check truncation
meta={"name": "0"},
id="1",
),
Document(
text="""That's bad. I don't like it.""",
meta={"name": "1"},
id="2",
),
]
results = zero_shot_document_classifier.predict(documents=docs)
expected_labels = ["positive", "negative"]
for i, doc in enumerate(results):
assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]