mirror of https://github.com/deepset-ai/haystack.git (synced 2025-11-03 03:09:28 +00:00)
feat: Extend TransformersQueryClassifier: clean version (#2965)
* extend query classifier in one commit
* variable number of outgoing edges
* improve tests
* fix unused import
* lightweight approach
* fix _calculate_outgoing_edges
* remove duplicate label validation
* Remove print
This commit is contained in:
parent c91316e862
commit 4a63484916
@@ -96,10 +96,11 @@ queries or statement vs question queries.
 class TransformersQueryClassifier(BaseQueryClassifier)
 ```
 
-A node to classify an incoming query into one of two categories using a (small) BERT transformer model.
+A node to classify an incoming query into categories using a transformer model.
 Depending on the result, the query flows to a different branch in your pipeline and the further processing
-can be customized. You can define this by connecting the further pipeline to either `output_1` or `output_2`
+can be customized. You can define this by connecting the further pipeline to `output_1`, `output_2`, ..., `output_n`
 from this node.
+This node also supports zero-shot-classification.
 
 **Example**:
 
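For context, here is a minimal sketch of how the variable-output routing introduced above can be wired into a pipeline. It assumes the Haystack v1 `Pipeline` API; the node names and the commented-out downstream retriever are illustrative, not part of this commit. The label order decides the edge: label i routes to `output_i`.

```python
# Minimal sketch: one outgoing edge per label (label i -> output_i).
from haystack import Pipeline
from haystack.nodes import TransformersQueryClassifier

classifier = TransformersQueryClassifier(
    model_name_or_path="typeform/distilbert-base-uncased-mnli",
    task="zero-shot-classification",
    labels=["music", "cinema", "food"],  # "music" -> output_1, "food" -> output_3
)

pipe = Pipeline()
pipe.add_node(component=classifier, name="QueryClassifier", inputs=["Query"])
# Attach one downstream node per outgoing edge, e.g. (hypothetical retriever):
# pipe.add_node(component=music_retriever, name="MusicRetriever", inputs=["QueryClassifier.output_1"])
res = pipe.run(query="recommend me some jazz albums")
```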
@@ -120,7 +121,7 @@ from this node.
 
 Models:
 
-Pass your own `Transformer` binary classification model from file/huggingface or use one of the following
+Pass your own `Transformer` classification/zero-shot-classification model from file/huggingface or use one of the following
 pretrained ones hosted on Huggingface:
 1) Keywords vs. Questions/Statements (Default)
    model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection"
@@ -143,13 +144,20 @@ from this node.
 #### TransformersQueryClassifier.\_\_init\_\_
 
 ```python
-def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", use_gpu: bool = True, batch_size: int = 16, progress_bar: bool = True)
+def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, task: str = "text-classification", labels: List[str] = DEFAULT_LABELS, batch_size: int = 16, progress_bar: bool = True)
 ```
 
 **Arguments**:
 
-- `model_name_or_path`: Transformer based fine tuned mini bert model for query classification
+- `model_name_or_path`: Directory of a saved model or the name of a public model, for example 'shahrukhx01/bert-mini-finetune-question-detection'.
+See [Hugging Face models](https://huggingface.co/models) for a full list of available models.
+- `model_version`: The version of the model to use from the Hugging Face model hub. This can be a tag name, a branch name, or a commit hash.
+- `tokenizer`: The name of the tokenizer (usually the same as model).
 - `use_gpu`: Whether to use GPU (if available).
-- `batch_size`: Batch size for inference.
+- `task`: Specifies the type of classification. Possible values: 'text-classification' or 'zero-shot-classification'.
+- `labels`: If the task is 'text-classification' and an ordered list of labels is provided, the first label corresponds to output_1,
+the second label to output_2, and so on. The labels must match the model labels; only the order can differ.
+If the task is 'zero-shot-classification', these are the candidate labels.
+- `batch_size`: The number of queries to be processed at a time.
 - `progress_bar`: Whether to show a progress bar.
 
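A short illustration of the label-order contract described in the `labels` argument above, assuming the default model. `AutoConfig` is standard `transformers` API; the printed mapping is indicative of what this model's config contains.

```python
# The provided labels must equal the model's labels as a set; their order
# decides which outgoing edge each label gets (first label -> output_1).
from transformers import AutoConfig

config = AutoConfig.from_pretrained("shahrukhx01/bert-mini-finetune-question-detection")
print(config.id2label)  # expected something like {0: "LABEL_0", 1: "LABEL_1"}

# With the default labels=["LABEL_1", "LABEL_0"], questions/statements
# ("LABEL_1") go to output_1 and keyword queries ("LABEL_0") to output_2.
```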
@@ -4801,11 +4801,35 @@
                 }
               ]
             },
+            "model_version": {
+              "title": "Model Version",
+              "type": "string"
+            },
+            "tokenizer": {
+              "title": "Tokenizer",
+              "type": "string"
+            },
             "use_gpu": {
               "title": "Use Gpu",
               "default": true,
               "type": "boolean"
             },
+            "task": {
+              "title": "Task",
+              "default": "text-classification",
+              "type": "string"
+            },
+            "labels": {
+              "title": "Labels",
+              "default": [
+                "LABEL_1",
+                "LABEL_0"
+              ],
+              "type": "array",
+              "items": {
+                "type": "string"
+              }
+            },
             "batch_size": {
               "title": "Batch Size",
               "default": 16,
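The schema fields added above describe the component's params as they would appear in a pipeline config. A hypothetical sketch of such a component definition, expressed as a Python dict (field names come from the schema; the values are illustrative, mostly the schema defaults):

```python
# Hypothetical component entry matching the new schema fields.
component = {
    "name": "QueryClassifier",
    "type": "TransformersQueryClassifier",
    "params": {
        "model_version": "main",           # tag, branch, or commit hash
        "tokenizer": "shahrukhx01/bert-mini-finetune-question-detection",
        "use_gpu": True,                   # schema default: true
        "task": "text-classification",     # schema default
        "labels": ["LABEL_1", "LABEL_0"],  # schema default
        "batch_size": 16,                  # schema default
    },
}
```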
@@ -1,22 +1,27 @@
 import logging
 from pathlib import Path
-from typing import Union, List, Optional, Dict
+from typing import Union, List, Optional, Dict, Any
 
+from transformers import pipeline
 from tqdm.auto import tqdm
-from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
 
 from haystack.nodes.query_classifier.base import BaseQueryClassifier
 from haystack.modeling.utils import initialize_device_settings
 from haystack.utils.torch_utils import ListDataset
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_LABELS = ["LABEL_1", "LABEL_0"]
+
 
 class TransformersQueryClassifier(BaseQueryClassifier):
     """
-    A node to classify an incoming query into one of two categories using a (small) BERT transformer model.
+    A node to classify an incoming query into categories using a transformer model.
     Depending on the result, the query flows to a different branch in your pipeline and the further processing
-    can be customized. You can define this by connecting the further pipeline to either `output_1` or `output_2`
+    can be customized. You can define this by connecting the further pipeline to `output_1`, `output_2`, ..., `output_n`
     from this node.
+    This node also supports zero-shot-classification.
 
     Example:
     ```python
@@ -36,7 +41,7 @@ class TransformersQueryClassifier(BaseQueryClassifier):
 
     Models:
 
-    Pass your own `Transformer` binary classification model from file/huggingface or use one of the following
+    Pass your own `Transformer` classification/zero-shot-classification model from file/huggingface or use one of the following
     pretrained ones hosted on Huggingface:
     1) Keywords vs. Questions/Statements (Default)
        model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection"
@@ -58,56 +63,98 @@ class TransformersQueryClassifier(BaseQueryClassifier):
     def __init__(
         self,
         model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection",
+        model_version: Optional[str] = None,
+        tokenizer: Optional[str] = None,
         use_gpu: bool = True,
+        task: str = "text-classification",
+        labels: List[str] = DEFAULT_LABELS,
         batch_size: int = 16,
         progress_bar: bool = True,
     ):
         """
-        :param model_name_or_path: Transformer based fine tuned mini bert model for query classification
+        :param model_name_or_path: Directory of a saved model or the name of a public model, for example 'shahrukhx01/bert-mini-finetune-question-detection'.
+            See [Hugging Face models](https://huggingface.co/models) for a full list of available models.
+        :param model_version: The version of the model to use from the Hugging Face model hub. This can be a tag name, a branch name, or a commit hash.
+        :param tokenizer: The name of the tokenizer (usually the same as model).
         :param use_gpu: Whether to use GPU (if available).
-        :param batch_size: Batch size for inference.
+        :param task: Specifies the type of classification. Possible values: 'text-classification' or 'zero-shot-classification'.
+        :param labels: If the task is 'text-classification' and an ordered list of labels is provided, the first label corresponds to output_1,
+            the second label to output_2, and so on. The labels must match the model labels; only the order can differ.
+            If the task is 'zero-shot-classification', these are the candidate labels.
+        :param batch_size: The number of queries to be processed at a time.
         :param progress_bar: Whether to show a progress bar.
         """
         super().__init__()
-        devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
-        device = 0 if devices[0].type == "cuda" else -1
+        self.devices, _ = initialize_device_settings(use_cuda=use_gpu)
+        device = 0 if self.devices[0].type == "cuda" else -1
+
+        self.model = pipeline(
+            task=task, model=model_name_or_path, tokenizer=tokenizer, device=device, revision=model_version
+        )
+
+        self.labels = labels
+        if task == "text-classification":
+            labels_from_model = [label for label in self.model.model.config.id2label.values()]
+            if set(labels) != set(labels_from_model):
+                raise ValueError(
+                    f"For text-classification, the provided labels must match the model labels; only the order can differ.\n"
+                    f"Provided labels: {labels}\n"
+                    f"Model labels: {labels_from_model}"
+                )
+        if task not in ["text-classification", "zero-shot-classification"]:
+            raise ValueError(
+                f"Task not supported: {task}.\n"
+                f"Possible task values are: 'text-classification' or 'zero-shot-classification'"
+            )
+        self.task = task
         self.batch_size = batch_size
         self.progress_bar = progress_bar
 
-        model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)
-        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-        self.query_classification_pipeline = TextClassificationPipeline(model=model, tokenizer=tokenizer, device=device)
-
-    def run(self, query):
-        is_question: bool = self.query_classification_pipeline(query)[0]["label"] == "LABEL_1"
-
-        if is_question:
-            return {}, "output_1"
-        else:
-            return {}, "output_2"
+    @classmethod
+    def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
+        labels = component_params.get("labels", DEFAULT_LABELS)
+        if labels is None or len(labels) == 0:
+            raise ValueError("The labels must be provided")
+        return len(labels)
+
+    def _get_edge_number_from_label(self, label):
+        return self.labels.index(label) + 1
+
+    def run(self, query: str):  # type: ignore
+        if self.task == "zero-shot-classification":
+            prediction = self.model([query], candidate_labels=self.labels, truncation=True)
+            label = prediction[0]["labels"][0]
+        elif self.task == "text-classification":
+            prediction = self.model([query], truncation=True)
+            label = prediction[0]["label"]
+        return {}, f"output_{self._get_edge_number_from_label(label)}"
 
     def run_batch(self, queries: List[str], batch_size: Optional[int] = None):  # type: ignore
-        if batch_size is None:
-            batch_size = self.batch_size
-
-        split: Dict[str, Dict[str, List]] = {"output_1": {"queries": []}, "output_2": {"queries": []}}
-
         # HF pb hack https://discuss.huggingface.co/t/progress-bar-for-hf-pipelines/20498/2
         queries_dataset = ListDataset(queries)
+        if batch_size is None:
+            batch_size = self.batch_size
         all_predictions = []
-        for predictions in tqdm(
-            self.query_classification_pipeline(queries_dataset, batch_size=batch_size),
-            disable=not self.progress_bar,
-            desc="Classifying queries",
-        ):
-            all_predictions.extend(predictions)
-
-        for query, pred in zip(queries, all_predictions):
-            if pred["label"] == "LABEL_1":
-                split["output_1"]["queries"].append(query)
-            else:
-                split["output_2"]["queries"].append(query)
-
-        return split, "split"
+        if self.task == "zero-shot-classification":
+            for predictions in tqdm(
+                self.model(queries_dataset, candidate_labels=self.labels, truncation=True, batch_size=batch_size),
+                disable=not self.progress_bar,
+                desc="Classifying queries",
+            ):
+                all_predictions.extend([predictions])
+        elif self.task == "text-classification":
+            for predictions in tqdm(
+                self.model(queries_dataset, truncation=True, batch_size=batch_size),
+                disable=not self.progress_bar,
+                desc="Classifying queries",
+            ):
+                all_predictions.extend([predictions])
+        results = {f"output_{self._get_edge_number_from_label(label)}": {"queries": []} for label in self.labels}  # type: ignore
+        for query, prediction in zip(queries, all_predictions):
+            if self.task == "zero-shot-classification":
+                label = prediction["labels"][0]
+            elif self.task == "text-classification":
+                label = prediction["label"]
+            results[f"output_{self._get_edge_number_from_label(label)}"]["queries"].append(query)
+
+        return results, "split"
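A standalone sketch of the label-to-edge arithmetic implemented by `_get_edge_number_from_label` above: the edge number is simply the 1-based position of the label in the configured list. No Haystack imports are needed; the labels are illustrative.

```python
labels = ["happy", "unhappy", "neutral"]

def edge_for(label: str) -> str:
    # Mirrors _get_edge_number_from_label: 1-based position in the labels list.
    return f"output_{labels.index(label) + 1}"

assert edge_for("happy") == "output_1"
assert edge_for("neutral") == "output_3"
```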
test/nodes/test_query_classifier.py (new file, 94 lines)
@@ -0,0 +1,94 @@
+import pytest
+from haystack.nodes.query_classifier.transformers import TransformersQueryClassifier
+
+
+@pytest.fixture
+def transformers_query_classifier():
+    return TransformersQueryClassifier(
+        model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
+        use_gpu=False,
+        task="text-classification",
+        labels=["LABEL_1", "LABEL_0"],
+    )
+
+
+@pytest.fixture
+def zero_shot_transformers_query_classifier():
+    return TransformersQueryClassifier(
+        model_name_or_path="typeform/distilbert-base-uncased-mnli",
+        use_gpu=False,
+        task="zero-shot-classification",
+        labels=["happy", "unhappy", "neutral"],
+    )
+
+
+def test_transformers_query_classifier(transformers_query_classifier):
+    output = transformers_query_classifier.run(query="morse code")
+    assert output == ({}, "output_2")
+
+    output = transformers_query_classifier.run(query="How old is John?")
+    assert output == ({}, "output_1")
+
+
+def test_transformers_query_classifier_batch(transformers_query_classifier):
+    queries = ["morse code", "How old is John?"]
+    output = transformers_query_classifier.run_batch(queries=queries)
+
+    assert output[0] == {"output_2": {"queries": ["morse code"]}, "output_1": {"queries": ["How old is John?"]}}
+
+
+def test_zero_shot_transformers_query_classifier(zero_shot_transformers_query_classifier):
+    output = zero_shot_transformers_query_classifier.run(query="What's the answer?")
+    assert output == ({}, "output_3")
+
+    output = zero_shot_transformers_query_classifier.run(query="Would you be so kind to tell me the answer?")
+    assert output == ({}, "output_1")
+
+    output = zero_shot_transformers_query_classifier.run(query="Can you give me the right answer for once??")
+    assert output == ({}, "output_2")
+
+
+def test_zero_shot_transformers_query_classifier_batch(zero_shot_transformers_query_classifier):
+    queries = [
+        "What's the answer?",
+        "Would you be so kind to tell me the answer?",
+        "Can you give me the right answer for once??",
+    ]
+    output = zero_shot_transformers_query_classifier.run_batch(queries=queries)
+
+    assert output[0] == {
+        "output_3": {"queries": ["What's the answer?"]},
+        "output_1": {"queries": ["Would you be so kind to tell me the answer?"]},
+        "output_2": {"queries": ["Can you give me the right answer for once??"]},
+    }
+
+
+def test_transformers_query_classifier_wrong_labels():
+    with pytest.raises(ValueError, match="For text-classification, the provided labels must match the model labels"):
+        query_classifier = TransformersQueryClassifier(
+            model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
+            use_gpu=False,
+            task="text-classification",
+            labels=["WRONG_LABEL_1", "WRONG_LABEL_2", "WRONG_LABEL_3"],
+        )
+
+
+def test_transformers_query_classifier_no_labels():
+    with pytest.raises(ValueError, match="The labels must be provided"):
+        query_classifier = TransformersQueryClassifier(
+            model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
+            use_gpu=False,
+            task="text-classification",
+            labels=None,
+        )
+
+
+def test_transformers_query_classifier_unsupported_task():
+    with pytest.raises(ValueError, match="Task not supported"):
+        query_classifier = TransformersQueryClassifier(
+            model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
+            use_gpu=False,
+            task="summarization",
+            labels=["LABEL_1", "LABEL_0"],
+        )
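The variable edge count that the zero-shot tests above rely on comes from `_calculate_outgoing_edges`, which derives it from the `labels` param. A quick illustration, assuming the classifier from this commit is importable:

```python
from haystack.nodes.query_classifier.transformers import TransformersQueryClassifier

# Three labels -> three outgoing edges (output_1..output_3), matching
# the zero-shot fixture used in the tests above.
n_edges = TransformersQueryClassifier._calculate_outgoing_edges(
    {"labels": ["happy", "unhappy", "neutral"]}
)
assert n_edges == 3
```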