From 24483d7bad9132e0551457f5dda9a1e16e99a37e Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Fri, 1 Oct 2021 11:22:56 +0200 Subject: [PATCH] TransformersDocumentClassifier replacing FARMClassifier (#1540) * Initial draft of TransformersClassifier * Add transformers classifier implementation * Add test for SentenceTransformersClassifier * Add truncation and corresponding test case to Classifier * Add zero-shot classification and test * Add document classifier documentation * Add latest docstring and tutorial changes * print meta data with print_documents() * Add latest docstring and tutorial changes * Remove top_k param from Classifier usage example * Add latest docstring and tutorial changes Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/_src/api/api/document_classifier.md | 108 ++++++++++++++++ docs/_src/api/api/generate_docstrings.sh | 1 + .../pydoc-markdown-document-classifier.yml | 18 +++ haystack/document_classifier/__init__.py | 1 + haystack/document_classifier/base.py | 57 +++++++++ haystack/document_classifier/transformers.py | 121 ++++++++++++++++++ haystack/utils.py | 7 +- test/conftest.py | 18 +++ test/test_document_classifier.py | 48 +++++++ 9 files changed, 376 insertions(+), 3 deletions(-) create mode 100644 docs/_src/api/api/document_classifier.md create mode 100644 docs/_src/api/api/pydoc-markdown-document-classifier.yml create mode 100644 haystack/document_classifier/__init__.py create mode 100644 haystack/document_classifier/base.py create mode 100644 haystack/document_classifier/transformers.py create mode 100644 test/test_document_classifier.py diff --git a/docs/_src/api/api/document_classifier.md b/docs/_src/api/api/document_classifier.md new file mode 100644 index 000000000..4ca1e2063 --- /dev/null +++ b/docs/_src/api/api/document_classifier.md @@ -0,0 +1,108 @@ + +# Module base + + +## BaseDocumentClassifier Objects + +```python +class BaseDocumentClassifier(BaseComponent) +``` + + 
+#### timing + +```python + | timing(fn, attr_name) +``` + +Wrapper method used to time functions. + + +# Module transformers + + +## TransformersDocumentClassifier Objects + +```python +class TransformersDocumentClassifier(BaseDocumentClassifier) +``` + +Transformer based model for document classification using HuggingFace's transformers framework +(https://github.com/huggingface/transformers). +While the underlying model can vary (BERT, Roberta, DistilBERT ...), the interface remains the same. +This node classifies documents and adds the output from the classification step to the document's metadata. +The meta field of the document is a dictionary with the following format: +'meta': {'name': '450_Baelor.txt', 'classification': {'label': 'neutral', 'score': 0.9997646, ...} } + +With this document_classifier, you can directly get predictions via predict() + +Usage example: +... +retriever = ElasticsearchRetriever(document_store=document_store) +document_classifier = TransformersDocumentClassifier(model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion") +p = Pipeline() +p.add_node(component=retriever, name="Retriever", inputs=["Query"]) +p.add_node(component=document_classifier, name="Classifier", inputs=["Retriever"]) +res = p.run( + query="Who is the father of Arya Stark?", + params={"Retriever": {"top_k": 10}} +) + +__print the classification results__ + +print_documents(res, max_text_len=100, print_meta=True) +__or access the predicted class label directly__ + +res["documents"][0].to_dict()["meta"]["classification"]["label"] + + +#### \_\_init\_\_ + +```python + | __init__(model_name_or_path: str = "bhadresh-savani/distilbert-base-uncased-emotion", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: int = 0, return_all_scores: bool = False, task: str = 'text-classification', labels: Optional[List[str]] = None) +``` + +Load a text classification model from Transformers. 
+Available models for the task of text-classification include: +- ``'bhadresh-savani/distilbert-base-uncased-emotion'`` +- ``'Hate-speech-CNERG/dehatebert-mono-english'`` + +Available models for the task of zero-shot-classification include: +- ``'valhalla/distilbart-mnli-12-3'`` +- ``'cross-encoder/nli-distilroberta-base'`` + +See https://huggingface.co/models for the full list of available models. +Filter for text classification models: https://huggingface.co/models?pipeline_tag=text-classification&sort=downloads +Filter for zero-shot classification models (NLI): https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli + +**Arguments**: + +- `model_name_or_path`: Directory of a saved model or the name of a public model e.g. 'bhadresh-savani/distilbert-base-uncased-emotion'. +See https://huggingface.co/models for the full list of available models. +- `model_version`: The version of the model to use from the HuggingFace model hub. Can be a tag name, branch name, or commit hash. +- `tokenizer`: Name of the tokenizer (usually the same as model) +- `use_gpu`: If < 0, then use cpu. If >= 0, this is the ordinal of the gpu to use +- `return_all_scores`: Whether to return all prediction scores or just the one of the predicted class. Only used for task 'text-classification'. +- `task`: 'text-classification' or 'zero-shot-classification' +- `labels`: Only used for task 'zero-shot-classification'. List of strings defining class labels, e.g., +["positive", "negative"] otherwise None. Given a LABEL, the sequence fed to the model is "<cls> sequence to +classify <sep> This example is LABEL . <sep>" and the model predicts whether that sequence is a contradiction +or an entailment. 
+ + +#### predict + +```python + | predict(documents: List[Document]) -> List[Document] +``` + +Returns documents containing classification result in meta field + +**Arguments**: + +- `documents`: List of Document to classify + +**Returns**: + +List of Document enriched with meta information + diff --git a/docs/_src/api/api/generate_docstrings.sh b/docs/_src/api/api/generate_docstrings.sh index f488e60a2..88e89cedd 100755 --- a/docs/_src/api/api/generate_docstrings.sh +++ b/docs/_src/api/api/generate_docstrings.sh @@ -17,4 +17,5 @@ pydoc-markdown pydoc-markdown-graph-retriever.yml pydoc-markdown pydoc-markdown-evaluation.yml pydoc-markdown pydoc-markdown-ranker.yml pydoc-markdown pydoc-markdown-question-generator.yml +pydoc-markdown pydoc-markdown-document-classifier.yml diff --git a/docs/_src/api/api/pydoc-markdown-document-classifier.yml b/docs/_src/api/api/pydoc-markdown-document-classifier.yml new file mode 100644 index 000000000..8d5bf04c8 --- /dev/null +++ b/docs/_src/api/api/pydoc-markdown-document-classifier.yml @@ -0,0 +1,18 @@ +loaders: + - type: python + search_path: [../../../../haystack/document_classifier] + modules: ['base', 'transformers'] + ignore_when_discovered: ['__init__'] +processor: + - type: filter + expression: not name.startswith('_') and default() + - documented_only: true + - do_not_filter_modules: false + - skip_empty_modules: true +renderer: + type: markdown + descriptive_class_title: true + descriptive_module_title: true + add_method_class_prefix: false + add_member_class_prefix: false + filename: document_classifier.md diff --git a/haystack/document_classifier/__init__.py b/haystack/document_classifier/__init__.py new file mode 100644 index 000000000..00b2aef7d --- /dev/null +++ b/haystack/document_classifier/__init__.py @@ -0,0 +1 @@ +from haystack.document_classifier.transformers import TransformersDocumentClassifier diff --git a/haystack/document_classifier/base.py b/haystack/document_classifier/base.py new file mode 100644 index 
000000000..5fb974776 --- /dev/null +++ b/haystack/document_classifier/base.py @@ -0,0 +1,57 @@ +import logging +from abc import abstractmethod +from typing import List +from functools import wraps +from time import perf_counter + + +from haystack import Document, BaseComponent + +logger = logging.getLogger(__name__) + + +class BaseDocumentClassifier(BaseComponent): + outgoing_edges = 1 + query_count = 0 + query_time = 0 + + @abstractmethod + def predict(self, documents: List[Document]): + pass + + def run(self, query: str, documents: List[Document]): # type: ignore + self.query_count += 1 + if documents: + predict = self.timing(self.predict, "query_time") + results = predict(documents=documents) + else: + results = [] + + document_ids = [doc.id for doc in results] + logger.debug(f"Retrieved documents with IDs: {document_ids}") + output = {"documents": results} + + return output, "output_1" + + def timing(self, fn, attr_name): + """Wrapper method used to time functions. """ + @wraps(fn) + def wrapper(*args, **kwargs): + if attr_name not in self.__dict__: + self.__dict__[attr_name] = 0 + tic = perf_counter() + ret = fn(*args, **kwargs) + toc = perf_counter() + self.__dict__[attr_name] += toc - tic + return ret + return wrapper + + def print_time(self): + print("Classifier (Speed)") + print("---------------") + if not self.query_count: + print("No querying performed via Classifier.run()") + else: + print(f"Queries Performed: {self.query_count}") + print(f"Query time: {self.query_time}s") + print(f"{self.query_time / self.query_count} seconds per query") \ No newline at end of file diff --git a/haystack/document_classifier/transformers.py b/haystack/document_classifier/transformers.py new file mode 100644 index 000000000..fb6ecd723 --- /dev/null +++ b/haystack/document_classifier/transformers.py @@ -0,0 +1,121 @@ +import logging +from typing import List, Optional + +from transformers import pipeline + +from haystack import Document +from 
haystack.document_classifier.base import BaseDocumentClassifier + +logger = logging.getLogger(__name__) + + +class TransformersDocumentClassifier(BaseDocumentClassifier): + """ + Transformer based model for document classification using HuggingFace's transformers framework + (https://github.com/huggingface/transformers). + While the underlying model can vary (BERT, Roberta, DistilBERT ...), the interface remains the same. + This node classifies documents and adds the output from the classification step to the document's metadata. + The meta field of the document is a dictionary with the following format: + 'meta': {'name': '450_Baelor.txt', 'classification': {'label': 'neutral', 'score': 0.9997646, ...} } + + With this document_classifier, you can directly get predictions via predict() + + Usage example: + ... + retriever = ElasticsearchRetriever(document_store=document_store) + document_classifier = TransformersDocumentClassifier(model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion") + p = Pipeline() + p.add_node(component=retriever, name="Retriever", inputs=["Query"]) + p.add_node(component=document_classifier, name="Classifier", inputs=["Retriever"]) + res = p.run( + query="Who is the father of Arya Stark?", + params={"Retriever": {"top_k": 10}} + ) + + # print the classification results + print_documents(res, max_text_len=100, print_meta=True) + # or access the predicted class label directly + res["documents"][0].to_dict()["meta"]["classification"]["label"] + """ + + def __init__( + self, + model_name_or_path: str = "bhadresh-savani/distilbert-base-uncased-emotion", + model_version: Optional[str] = None, + tokenizer: Optional[str] = None, + use_gpu: int = 0, + return_all_scores: bool = False, + task: str = 'text-classification', + labels: Optional[List[str]] = None + ): + """ + Load a text classification model from Transformers. 
+ Available models for the task of text-classification include: + - ``'bhadresh-savani/distilbert-base-uncased-emotion'`` + - ``'Hate-speech-CNERG/dehatebert-mono-english'`` + + Available models for the task of zero-shot-classification include: + - ``'valhalla/distilbart-mnli-12-3'`` + - ``'cross-encoder/nli-distilroberta-base'`` + + See https://huggingface.co/models for the full list of available models. + Filter for text classification models: https://huggingface.co/models?pipeline_tag=text-classification&sort=downloads + Filter for zero-shot classification models (NLI): https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli + + :param model_name_or_path: Directory of a saved model or the name of a public model e.g. 'bhadresh-savani/distilbert-base-uncased-emotion'. + See https://huggingface.co/models for the full list of available models. + :param model_version: The version of the model to use from the HuggingFace model hub. Can be a tag name, branch name, or commit hash. + :param tokenizer: Name of the tokenizer (usually the same as model) + :param use_gpu: If < 0, then use cpu. If >= 0, this is the ordinal of the gpu to use + :param return_all_scores: Whether to return all prediction scores or just the one of the predicted class. Only used for task 'text-classification'. + :param task: 'text-classification' or 'zero-shot-classification' + :param labels: Only used for task 'zero-shot-classification'. List of strings defining class labels, e.g., + ["positive", "negative"] otherwise None. Given a LABEL, the sequence fed to the model is "<cls> sequence to + classify <sep> This example is LABEL . <sep>" and the model predicts whether that sequence is a contradiction + or an entailment. 
+ + """ + + # save init parameters to enable export of component config as YAML + self.set_config( + model_name_or_path=model_name_or_path, model_version=model_version, tokenizer=tokenizer, + use_gpu=use_gpu, return_all_scores=return_all_scores, labels=labels, task=task + ) + if labels and task == 'text-classification': + logger.warning(f'Provided labels {labels} will be ignored for task text-classification. Set task to ' + f'zero-shot-classification to use labels.') + + if tokenizer is None: + tokenizer = model_name_or_path + if task == 'zero-shot-classification': + self.model = pipeline(task=task, model=model_name_or_path, tokenizer=tokenizer, device=use_gpu, revision=model_version) + elif task == 'text-classification': + self.model = pipeline(task=task, model=model_name_or_path, tokenizer=tokenizer, device=use_gpu, revision=model_version, return_all_scores=return_all_scores) + self.return_all_scores = return_all_scores + self.labels = labels + self.task = task + + def predict(self, documents: List[Document]) -> List[Document]: + """ + Returns documents containing classification result in meta field + + :param documents: List of Document to classify + :return: List of Document enriched with meta information + + """ + texts = [doc.text for doc in documents] + if self.task == 'zero-shot-classification': + predictions = self.model(texts, candidate_labels=self.labels, truncation=True) + elif self.task == 'text-classification': + predictions = self.model(texts, return_all_scores=self.return_all_scores, truncation=True) + + classified_docs: List[Document] = [] + + for prediction, doc in zip(predictions, documents): + cur_doc = doc + cur_doc.meta["classification"] = prediction + if self.task == 'zero-shot-classification': + cur_doc.meta["classification"]["label"] = cur_doc.meta["classification"]["labels"][0] + classified_docs.append(cur_doc) + + return classified_docs diff --git a/haystack/utils.py b/haystack/utils.py index d1a4b889d..60dca8b1c 100644 --- 
a/haystack/utils.py +++ b/haystack/utils.py @@ -4,7 +4,7 @@ from itertools import islice import logging import pprint import pandas as pd -from typing import Dict, Any, List +from typing import Dict, Any, List, Optional from haystack.document_store.sql import DocumentORM import subprocess import time @@ -134,8 +134,7 @@ def print_answers(results: dict, details: str = "all"): pp.pprint(results) - -def print_documents(results: dict, max_text_len: int=None): +def print_documents(results: dict, max_text_len: Optional[int] = None, print_meta: bool = False): print(f"Query: {results['query']}") pp = pprint.PrettyPrinter(indent=4) for d in results["documents"]: @@ -147,6 +146,8 @@ def print_documents(results: dict, max_text_len: int=None): "name": d["meta"]["name"], "text": new_text } + if print_meta: + results["meta"] = d["meta"] pp.pprint(results) diff --git a/test/conftest.py b/test/conftest.py index 57b1c5169..e91234323 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -19,6 +19,7 @@ from haystack.document_store.milvus import MilvusDocumentStore from haystack.generator.transformers import RAGenerator, RAGeneratorType from haystack.modeling.infer import Inferencer, QAInferencer from haystack.ranker import SentenceTransformersRanker +from haystack.document_classifier.transformers import TransformersDocumentClassifier from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever, ElasticsearchRetriever, TfidfRetriever @@ -337,6 +338,23 @@ def ranker(): ) +@pytest.fixture(scope="module") +def document_classifier(): + return TransformersDocumentClassifier( + model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion", + use_gpu=-1 + ) + +@pytest.fixture(scope="module") +def zero_shot_document_classifier(): + return TransformersDocumentClassifier( + model_name_or_path="cross-encoder/nli-distilroberta-base", + use_gpu=-1, + task="zero-shot-classification", + labels=["negative", "positive"] + ) + + # TODO Fix bug in test_no_answer_output when using # 
@pytest.fixture(params=["farm", "transformers"]) @pytest.fixture(params=["farm"], scope="module") diff --git a/test/test_document_classifier.py b/test/test_document_classifier.py new file mode 100644 index 000000000..2dfc09254 --- /dev/null +++ b/test/test_document_classifier.py @@ -0,0 +1,48 @@ +import pytest + +from haystack import Document +from haystack.document_classifier.base import BaseDocumentClassifier + + +@pytest.mark.slow +def test_document_classifier(document_classifier): + assert isinstance(document_classifier, BaseDocumentClassifier) + + docs = [ + Document( + text="""That's good. I like it."""*700, # extra long text to check truncation + meta={"name": "0"}, + id="1", + ), + Document( + text="""That's bad. I don't like it.""", + meta={"name": "1"}, + id="2", + ), + ] + results = document_classifier.predict(documents=docs) + expected_labels = ["joy", "sadness"] + for i, doc in enumerate(results): + assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i] + + +@pytest.mark.slow +def test_zero_shot_document_classifier(zero_shot_document_classifier): + assert isinstance(zero_shot_document_classifier, BaseDocumentClassifier) + + docs = [ + Document( + text="""That's good. I like it."""*700, # extra long text to check truncation + meta={"name": "0"}, + id="1", + ), + Document( + text="""That's bad. I don't like it.""", + meta={"name": "1"}, + id="2", + ), + ] + results = zero_shot_document_classifier.predict(documents=docs) + expected_labels = ["positive", "negative"] + for i, doc in enumerate(results): + assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]