From 89247b804cc4c2ebdf7d45be2722bd88bfeead69 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> Date: Wed, 21 Sep 2022 13:16:03 +0200 Subject: [PATCH] refactor: make `TransformersDocumentClassifier` output consistent between different types of classification (#3224) * make output consistent * make output consistent * added tests for details * better tests * Update test_document_classifier.py * make black happy * Update test_document_classifier.py * Update test_document_classifier.py --- docs/_src/api/api/document_classifier.md | 6 ++-- .../haystack-pipeline-main.schema.json | 15 ++++++--- .../nodes/document_classifier/transformers.py | 33 ++++++++++++------- test/conftest.py | 2 +- test/document_stores/test_faiss_and_milvus.py | 2 +- test/nodes/test_document_classifier.py | 20 +++++++++++ 6 files changed, 57 insertions(+), 21 deletions(-) diff --git a/docs/_src/api/api/document_classifier.md b/docs/_src/api/api/document_classifier.md index acae37f8a..18afebfd0 100644 --- a/docs/_src/api/api/document_classifier.md +++ b/docs/_src/api/api/document_classifier.md @@ -37,7 +37,7 @@ Transformer based model for document classification using the HuggingFace's tran While the underlying model can vary (BERT, Roberta, DistilBERT ...), the interface remains the same. This node classifies documents and adds the output from the classification step to the document's meta data. The meta field of the document is a dictionary with the following format: -``'meta': {'name': '450_Baelor.txt', 'classification': {'label': 'neutral', 'probability' = 0.9997646, ...} }`` +``'meta': {'name': '450_Baelor.txt', 'classification': {'label': 'love', 'score': 0.960899, 'details': {'love': 0.960899, 'joy': 0.032584, ...}}}`` Classification is run on document's content field by default. If you want it to run on another field, set the `classification_field` to one of document's meta fields. @@ -89,7 +89,7 @@ def __init__(model_name_or_path: model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, - return_all_scores: bool = False, + top_k: Optional[int] = 1, task: str = "text-classification", labels: Optional[List[str]] = None, batch_size: int = 16, @@ -120,7 +120,7 @@ See https://huggingface.co/models for full list of available models. - `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash. - `tokenizer`: Name of the tokenizer (usually the same as model) - `use_gpu`: Whether to use GPU (if available). -- `return_all_scores`: Whether to return all prediction scores or just the one of the predicted class. Only used for task 'text-classification'. +- `top_k`: The number of top predictions to return. The default is 1. Enter None to return all the predictions. Only used for task 'text-classification'. - `task`: 'text-classification' or 'zero-shot-classification' - `labels`: Only used for task 'zero-shot-classification'. List of string defining class labels, e.g., ["positive", "negative"] otherwise None. Given a LABEL, the sequence fed to the model is " sequence to diff --git a/haystack/json-schemas/haystack-pipeline-main.schema.json b/haystack/json-schemas/haystack-pipeline-main.schema.json index b274b3f17..f00c87f82 100644 --- a/haystack/json-schemas/haystack-pipeline-main.schema.json +++ b/haystack/json-schemas/haystack-pipeline-main.schema.json @@ -5823,10 +5823,17 @@ "default": true, "type": "boolean" }, - "return_all_scores": { - "title": "Return All Scores", - "default": false, - "type": "boolean" + "top_k": { + "title": "Top K", + "default": 1, + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ] }, "task": { "title": "Task", diff --git a/haystack/nodes/document_classifier/transformers.py b/haystack/nodes/document_classifier/transformers.py index 3c76da2dc..d67e0c553 100644 --- a/haystack/nodes/document_classifier/transformers.py +++ b/haystack/nodes/document_classifier/transformers.py @@ -21,7 +21,7 @@ class TransformersDocumentClassifier(BaseDocumentClassifier): While the underlying model can vary (BERT, Roberta, DistilBERT ...), the interface remains the same. This node classifies documents and adds the output from the classification step to the document's meta data. The meta field of the document is a dictionary with the following format: - ``'meta': {'name': '450_Baelor.txt', 'classification': {'label': 'neutral', 'probability' = 0.9997646, ...} }`` + ``'meta': {'name': '450_Baelor.txt', 'classification': {'label': 'love', 'score': 0.960899, 'details': {'love': 0.960899, 'joy': 0.032584, ...}}}`` Classification is run on document's content field by default. If you want it to run on another field, set the `classification_field` to one of document's meta fields. @@ -70,7 +70,7 @@ class TransformersDocumentClassifier(BaseDocumentClassifier): model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, - return_all_scores: bool = False, + top_k: Optional[int] = 1, task: str = "text-classification", labels: Optional[List[str]] = None, batch_size: int = 16, @@ -98,7 +98,7 @@ class TransformersDocumentClassifier(BaseDocumentClassifier): :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash. :param tokenizer: Name of the tokenizer (usually the same as model) :param use_gpu: Whether to use GPU (if available). - :param return_all_scores: Whether to return all prediction scores or just the one of the predicted class. Only used for task 'text-classification'. + :param top_k: The number of top predictions to return. The default is 1. Enter None to return all the predictions. Only used for task 'text-classification'. :param task: 'text-classification' or 'zero-shot-classification' :param labels: Only used for task 'zero-shot-classification'. List of string defining class labels, e.g., ["positive", "negative"] otherwise None. Given a LABEL, the sequence fed to the model is " sequence to @@ -150,10 +150,10 @@ class TransformersDocumentClassifier(BaseDocumentClassifier): tokenizer=tokenizer, device=resolved_devices[0], revision=model_version, - return_all_scores=return_all_scores, + top_k=top_k, use_auth_token=use_auth_token, ) - self.return_all_scores = return_all_scores + self.top_k = top_k self.labels = labels self.task = task self.batch_size = batch_size @@ -177,22 +177,31 @@ class TransformersDocumentClassifier(BaseDocumentClassifier): for doc in documents ] batches = self.get_batches(texts, batch_size=batch_size) - batched_predictions = [] - pb = tqdm(total=len(texts), disable=not self.progress_bar, desc="Generating questions") + predictions = [] + pb = tqdm(total=len(texts), disable=not self.progress_bar, desc="Classifying documents") for batch in batches: if self.task == "zero-shot-classification": batched_prediction = self.model(batch, candidate_labels=self.labels, truncation=True) elif self.task == "text-classification": - batched_prediction = self.model(batch, return_all_scores=self.return_all_scores, truncation=True) - batched_predictions.append(batched_prediction) + batched_prediction = self.model(batch, top_k=self.top_k, truncation=True) + predictions.extend(batched_prediction) pb.update(len(batch)) pb.close() - predictions = [pred for batched_prediction in batched_predictions for pred in batched_prediction] for prediction, doc in zip(predictions, documents): if self.task == "zero-shot-classification": - prediction["label"] = prediction["labels"][0] - doc.meta["classification"] = prediction + formatted_prediction = { + "label": prediction["labels"][0], + "score": prediction["scores"][0], + "details": {label: score for label, score in zip(prediction["labels"], prediction["scores"])}, + } + elif self.task == "text-classification": + formatted_prediction = { + "label": prediction[0]["label"], + "score": prediction[0]["score"], + "details": {el["label"]: el["score"] for el in prediction}, + } + doc.meta["classification"] = formatted_prediction return documents diff --git a/test/conftest.py b/test/conftest.py index 615ff9edd..5648fc261 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -621,7 +621,7 @@ def ranker(): @pytest.fixture def document_classifier(): return TransformersDocumentClassifier( - model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion", use_gpu=False + model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion", use_gpu=False, top_k=2 ) diff --git a/test/document_stores/test_faiss_and_milvus.py b/test/document_stores/test_faiss_and_milvus.py index 6a9714ca6..cf0bc4a78 100644 --- a/test/document_stores/test_faiss_and_milvus.py +++ b/test/document_stores/test_faiss_and_milvus.py @@ -16,7 +16,7 @@ from haystack.document_stores.faiss import FAISSDocumentStore from haystack.pipelines import Pipeline from haystack.nodes.retriever.dense import EmbeddingRetriever -from ..conftest import document_classifier, ensure_ids_are_correct_uuids, SAMPLES_PATH, MockDenseRetriever +from ..conftest import ensure_ids_are_correct_uuids, SAMPLES_PATH, MockDenseRetriever DOCUMENTS = [ diff --git a/test/nodes/test_document_classifier.py b/test/nodes/test_document_classifier.py index d2b5634ed..1a374e8ee 100644 --- a/test/nodes/test_document_classifier.py +++ b/test/nodes/test_document_classifier.py @@ -22,6 +22,16 @@ def test_document_classifier(document_classifier): assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i] +@pytest.mark.integration +def test_document_classifier_details(document_classifier): + + docs = [Document(content="""That's good. I like it."""), Document(content="""That's bad. I don't like it.""")] + results = document_classifier.predict(documents=docs) + for doc in results: + assert "details" in doc.meta["classification"] + assert len(doc.meta["classification"]["details"]) == 2 # top_k = 2 + + @pytest.mark.integration def test_document_classifier_batch_single_doc_list(document_classifier): docs = [ @@ -65,6 +75,16 @@ def test_zero_shot_document_classifier(zero_shot_document_classifier): assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i] +@pytest.mark.integration +def test_zero_shot_document_classifier_details(zero_shot_document_classifier): + + docs = [Document(content="""That's good. I like it."""), Document(content="""That's bad. I don't like it.""")] + results = zero_shot_document_classifier.predict(documents=docs) + for doc in results: + assert "details" in doc.meta["classification"] + assert len(doc.meta["classification"]["details"]) == 2 # n_labels = 2 + + @pytest.mark.integration def test_document_classifier_batch_size(batched_document_classifier): assert isinstance(batched_document_classifier, BaseDocumentClassifier)