mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-08 04:56:45 +00:00
refactor: make TransformersDocumentClassifier output consistent between different types of classification (#3224)
* make output consistent * make output consistent * added tests for details * better tests * Update test_document_classifier.py * make black happy * Update test_document_classifier.py * Update test_document_classifier.py
This commit is contained in:
parent
15bb6c2ea2
commit
89247b804c
@ -37,7 +37,7 @@ Transformer based model for document classification using the HuggingFace's tran
|
||||
While the underlying model can vary (BERT, Roberta, DistilBERT ...), the interface remains the same.
|
||||
This node classifies documents and adds the output from the classification step to the document's meta data.
|
||||
The meta field of the document is a dictionary with the following format:
|
||||
``'meta': {'name': '450_Baelor.txt', 'classification': {'label': 'neutral', 'probability' = 0.9997646, ...} }``
|
||||
``'meta': {'name': '450_Baelor.txt', 'classification': {'label': 'love', 'score': 0.960899, 'details': {'love': 0.960899, 'joy': 0.032584, ...}}}``
|
||||
|
||||
Classification is run on document's content field by default. If you want it to run on another field,
|
||||
set the `classification_field` to one of document's meta fields.
|
||||
@ -89,7 +89,7 @@ def __init__(model_name_or_path:
|
||||
model_version: Optional[str] = None,
|
||||
tokenizer: Optional[str] = None,
|
||||
use_gpu: bool = True,
|
||||
return_all_scores: bool = False,
|
||||
top_k: Optional[int] = 1,
|
||||
task: str = "text-classification",
|
||||
labels: Optional[List[str]] = None,
|
||||
batch_size: int = 16,
|
||||
@ -120,7 +120,7 @@ See https://huggingface.co/models for full list of available models.
|
||||
- `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
|
||||
- `tokenizer`: Name of the tokenizer (usually the same as model)
|
||||
- `use_gpu`: Whether to use GPU (if available).
|
||||
- `return_all_scores`: Whether to return all prediction scores or just the one of the predicted class. Only used for task 'text-classification'.
|
||||
- `top_k`: The number of top predictions to return. The default is 1. Enter None to return all the predictions. Only used for task 'text-classification'.
|
||||
- `task`: 'text-classification' or 'zero-shot-classification'
|
||||
- `labels`: Only used for task 'zero-shot-classification'. List of string defining class labels, e.g.,
|
||||
["positive", "negative"] otherwise None. Given a LABEL, the sequence fed to the model is "<cls> sequence to
|
||||
|
||||
@ -5823,10 +5823,17 @@
|
||||
"default": true,
|
||||
"type": "boolean"
|
||||
},
|
||||
"return_all_scores": {
|
||||
"title": "Return All Scores",
|
||||
"default": false,
|
||||
"type": "boolean"
|
||||
"top_k": {
|
||||
"title": "Top K",
|
||||
"default": 1,
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "integer"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
]
|
||||
},
|
||||
"task": {
|
||||
"title": "Task",
|
||||
|
||||
@ -21,7 +21,7 @@ class TransformersDocumentClassifier(BaseDocumentClassifier):
|
||||
While the underlying model can vary (BERT, Roberta, DistilBERT ...), the interface remains the same.
|
||||
This node classifies documents and adds the output from the classification step to the document's meta data.
|
||||
The meta field of the document is a dictionary with the following format:
|
||||
``'meta': {'name': '450_Baelor.txt', 'classification': {'label': 'neutral', 'probability' = 0.9997646, ...} }``
|
||||
``'meta': {'name': '450_Baelor.txt', 'classification': {'label': 'love', 'score': 0.960899, 'details': {'love': 0.960899, 'joy': 0.032584, ...}}}``
|
||||
|
||||
Classification is run on document's content field by default. If you want it to run on another field,
|
||||
set the `classification_field` to one of document's meta fields.
|
||||
@ -70,7 +70,7 @@ class TransformersDocumentClassifier(BaseDocumentClassifier):
|
||||
model_version: Optional[str] = None,
|
||||
tokenizer: Optional[str] = None,
|
||||
use_gpu: bool = True,
|
||||
return_all_scores: bool = False,
|
||||
top_k: Optional[int] = 1,
|
||||
task: str = "text-classification",
|
||||
labels: Optional[List[str]] = None,
|
||||
batch_size: int = 16,
|
||||
@ -98,7 +98,7 @@ class TransformersDocumentClassifier(BaseDocumentClassifier):
|
||||
:param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
|
||||
:param tokenizer: Name of the tokenizer (usually the same as model)
|
||||
:param use_gpu: Whether to use GPU (if available).
|
||||
:param return_all_scores: Whether to return all prediction scores or just the one of the predicted class. Only used for task 'text-classification'.
|
||||
:param top_k: The number of top predictions to return. The default is 1. Enter None to return all the predictions. Only used for task 'text-classification'.
|
||||
:param task: 'text-classification' or 'zero-shot-classification'
|
||||
:param labels: Only used for task 'zero-shot-classification'. List of string defining class labels, e.g.,
|
||||
["positive", "negative"] otherwise None. Given a LABEL, the sequence fed to the model is "<cls> sequence to
|
||||
@ -150,10 +150,10 @@ class TransformersDocumentClassifier(BaseDocumentClassifier):
|
||||
tokenizer=tokenizer,
|
||||
device=resolved_devices[0],
|
||||
revision=model_version,
|
||||
return_all_scores=return_all_scores,
|
||||
top_k=top_k,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
self.return_all_scores = return_all_scores
|
||||
self.top_k = top_k
|
||||
self.labels = labels
|
||||
self.task = task
|
||||
self.batch_size = batch_size
|
||||
@ -177,22 +177,31 @@ class TransformersDocumentClassifier(BaseDocumentClassifier):
|
||||
for doc in documents
|
||||
]
|
||||
batches = self.get_batches(texts, batch_size=batch_size)
|
||||
batched_predictions = []
|
||||
pb = tqdm(total=len(texts), disable=not self.progress_bar, desc="Generating questions")
|
||||
predictions = []
|
||||
pb = tqdm(total=len(texts), disable=not self.progress_bar, desc="Classifying documents")
|
||||
for batch in batches:
|
||||
if self.task == "zero-shot-classification":
|
||||
batched_prediction = self.model(batch, candidate_labels=self.labels, truncation=True)
|
||||
elif self.task == "text-classification":
|
||||
batched_prediction = self.model(batch, return_all_scores=self.return_all_scores, truncation=True)
|
||||
batched_predictions.append(batched_prediction)
|
||||
batched_prediction = self.model(batch, top_k=self.top_k, truncation=True)
|
||||
predictions.extend(batched_prediction)
|
||||
pb.update(len(batch))
|
||||
pb.close()
|
||||
predictions = [pred for batched_prediction in batched_predictions for pred in batched_prediction]
|
||||
|
||||
for prediction, doc in zip(predictions, documents):
|
||||
if self.task == "zero-shot-classification":
|
||||
prediction["label"] = prediction["labels"][0]
|
||||
doc.meta["classification"] = prediction
|
||||
formatted_prediction = {
|
||||
"label": prediction["labels"][0],
|
||||
"score": prediction["scores"][0],
|
||||
"details": {label: score for label, score in zip(prediction["labels"], prediction["scores"])},
|
||||
}
|
||||
elif self.task == "text-classification":
|
||||
formatted_prediction = {
|
||||
"label": prediction[0]["label"],
|
||||
"score": prediction[0]["score"],
|
||||
"details": {el["label"]: el["score"] for el in prediction},
|
||||
}
|
||||
doc.meta["classification"] = formatted_prediction
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
@ -621,7 +621,7 @@ def ranker():
|
||||
@pytest.fixture
|
||||
def document_classifier():
|
||||
return TransformersDocumentClassifier(
|
||||
model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion", use_gpu=False
|
||||
model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion", use_gpu=False, top_k=2
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ from haystack.document_stores.faiss import FAISSDocumentStore
|
||||
from haystack.pipelines import Pipeline
|
||||
from haystack.nodes.retriever.dense import EmbeddingRetriever
|
||||
|
||||
from ..conftest import document_classifier, ensure_ids_are_correct_uuids, SAMPLES_PATH, MockDenseRetriever
|
||||
from ..conftest import ensure_ids_are_correct_uuids, SAMPLES_PATH, MockDenseRetriever
|
||||
|
||||
|
||||
DOCUMENTS = [
|
||||
|
||||
@ -22,6 +22,16 @@ def test_document_classifier(document_classifier):
|
||||
assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_document_classifier_details(document_classifier):
|
||||
|
||||
docs = [Document(content="""That's good. I like it."""), Document(content="""That's bad. I don't like it.""")]
|
||||
results = document_classifier.predict(documents=docs)
|
||||
for doc in results:
|
||||
assert "details" in doc.meta["classification"]
|
||||
assert len(doc.meta["classification"]["details"]) == 2 # top_k = 2
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_document_classifier_batch_single_doc_list(document_classifier):
|
||||
docs = [
|
||||
@ -65,6 +75,16 @@ def test_zero_shot_document_classifier(zero_shot_document_classifier):
|
||||
assert doc.to_dict()["meta"]["classification"]["label"] == expected_labels[i]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_zero_shot_document_classifier_details(zero_shot_document_classifier):
|
||||
|
||||
docs = [Document(content="""That's good. I like it."""), Document(content="""That's bad. I don't like it.""")]
|
||||
results = zero_shot_document_classifier.predict(documents=docs)
|
||||
for doc in results:
|
||||
assert "details" in doc.meta["classification"]
|
||||
assert len(doc.meta["classification"]["details"]) == 2 # n_labels = 2
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_document_classifier_batch_size(batched_document_classifier):
|
||||
assert isinstance(batched_document_classifier, BaseDocumentClassifier)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user