Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-11-09 06:13:43 +00:00)
feat: Extend TransformersQueryClassifier: clean version (#2965)
* extend query classifier in one commit
* variable number of outgoing edges
* improve tests
* fix unused import
* lightweight approach
* fix _calculate_outgoing_edges
* remove duplicate label validation
* Remove print
parent c91316e862
commit 4a63484916
@@ -96,10 +96,11 @@ queries or statement vs question queries.
 class TransformersQueryClassifier(BaseQueryClassifier)
 ```
 
-A node to classify an incoming query into one of two categories using a (small) BERT transformer model.
+A node to classify an incoming query into categories using a transformer model.
 Depending on the result, the query flows to a different branch in your pipeline and the further processing
-can be customized. You can define this by connecting the further pipeline to either `output_1` or `output_2`
+can be customized. You can define this by connecting the further pipeline to `output_1`, `output_2`, ..., `output_n`
 from this node.
+This node also supports zero-shot-classification.
 
 **Example**:
 

@@ -120,7 +121,7 @@ from this node.
 
 Models:
 
-Pass your own `Transformer` binary classification model from file/huggingface or use one of the following
+Pass your own `Transformer` classification/zero-shot-classification model from file/huggingface or use one of the following
 pretrained ones hosted on Huggingface:
 1) Keywords vs. Questions/Statements (Default)
 model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection"

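A minimal sketch of the default text-classification behaviour described above (illustrative, not part of the diff; it downloads the default model on first run, and the expected outputs mirror this commit's tests, where `LABEL_1` marks questions/statements and `LABEL_0` marks keyword queries):

```python
# Hedged sketch: with no arguments, the node defaults to task="text-classification"
# and labels=DEFAULT_LABELS=["LABEL_1", "LABEL_0"], so LABEL_1 routes to output_1
# and LABEL_0 routes to output_2.
from haystack.nodes.query_classifier.transformers import TransformersQueryClassifier

classifier = TransformersQueryClassifier(use_gpu=False)
print(classifier.run(query="How old is John?"))  # expected: ({}, "output_1")
print(classifier.run(query="morse code"))        # expected: ({}, "output_2")
```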
@@ -143,13 +144,20 @@ from this node.
 #### TransformersQueryClassifier.\_\_init\_\_
 
 ```python
-def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", use_gpu: bool = True, batch_size: int = 16, progress_bar: bool = True)
+def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, task: str = "text-classification", labels: List[str] = DEFAULT_LABELS, batch_size: int = 16, progress_bar: bool = True)
 ```
 
 **Arguments**:
 
-- `model_name_or_path`: Transformer based fine tuned mini bert model for query classification
+- `model_name_or_path`: Directory of a saved model or the name of a public model, for example 'shahrukhx01/bert-mini-finetune-question-detection'.
+See [Hugging Face models](https://huggingface.co/models) for a full list of available models.
+- `model_version`: The version of the model to use from the Hugging Face model hub. This can be a tag name, a branch name, or a commit hash.
+- `tokenizer`: The name of the tokenizer (usually the same as model).
 - `use_gpu`: Whether to use GPU (if available).
-- `batch_size`: Batch size for inference.
+- `task`: Specifies the type of classification. Possible values: 'text-classification' or 'zero-shot-classification'.
+- `labels`: If the task is 'text-classification' and an ordered list of labels is provided, the first label corresponds to output_1,
+the second label to output_2, and so on. The labels must match the model labels; only the order can differ.
+If the task is 'zero-shot-classification', these are the candidate labels.
+- `batch_size`: The number of queries to be processed at a time.
 - `progress_bar`: Whether to show a progress bar.
 

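Because `labels` now determines the number of outgoing edges, a node configured with n labels exposes `output_1` through `output_n`. A hedged wiring sketch (the zero-shot model is the one used in this commit's tests; the sentiment labels and downstream node names are illustrative only):

```python
from haystack import Pipeline
from haystack.nodes.query_classifier.transformers import TransformersQueryClassifier

# Three candidate labels -> three outgoing edges; the i-th label routes to output_i.
classifier = TransformersQueryClassifier(
    model_name_or_path="typeform/distilbert-base-uncased-mnli",
    use_gpu=False,
    task="zero-shot-classification",
    labels=["happy", "unhappy", "neutral"],
)

pipe = Pipeline()
pipe.add_node(component=classifier, name="QueryClassifier", inputs=["Query"])
# Downstream nodes attach to one edge each, for example (hypothetical node names):
# pipe.add_node(component=happy_node, name="HappyBranch", inputs=["QueryClassifier.output_1"])
# pipe.add_node(component=unhappy_node, name="UnhappyBranch", inputs=["QueryClassifier.output_2"])
# pipe.add_node(component=neutral_node, name="NeutralBranch", inputs=["QueryClassifier.output_3"])
```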
@@ -4801,11 +4801,35 @@
         }
       ]
     },
+    "model_version": {
+      "title": "Model Version",
+      "type": "string"
+    },
+    "tokenizer": {
+      "title": "Tokenizer",
+      "type": "string"
+    },
     "use_gpu": {
       "title": "Use Gpu",
       "default": true,
       "type": "boolean"
     },
+    "task": {
+      "title": "Task",
+      "default": "text-classification",
+      "type": "string"
+    },
+    "labels": {
+      "title": "Labels",
+      "default": [
+        "LABEL_1",
+        "LABEL_0"
+      ],
+      "type": "array",
+      "items": {
+        "type": "string"
+      }
+    },
     "batch_size": {
       "title": "Batch Size",
       "default": 16,

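These schema additions mean the new parameters can be set from a serialized pipeline definition. A hedged sketch, assuming Haystack's `Pipeline.load_from_config` and that `version: ignore` skips version validation (the config dict is illustrative, not taken from the commit):

```python
from haystack import Pipeline

# Dict equivalent of a pipeline YAML; the params below are the ones validated
# against the JSON schema fields added in this hunk ("task", "labels", ...).
config = {
    "version": "ignore",
    "components": [
        {
            "name": "QueryClassifier",
            "type": "TransformersQueryClassifier",
            "params": {
                "use_gpu": False,
                "task": "zero-shot-classification",
                "labels": ["happy", "unhappy", "neutral"],
            },
        }
    ],
    "pipelines": [{"name": "query", "nodes": [{"name": "QueryClassifier", "inputs": ["Query"]}]}],
}

pipe = Pipeline.load_from_config(config, pipeline_name="query")
```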
@@ -1,22 +1,27 @@
 import logging
 from pathlib import Path
-from typing import Union, List, Optional, Dict
+from typing import Union, List, Optional, Dict, Any
 
+from transformers import pipeline
 from tqdm.auto import tqdm
-from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
+
+# from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
 from haystack.nodes.query_classifier.base import BaseQueryClassifier
 from haystack.modeling.utils import initialize_device_settings
 from haystack.utils.torch_utils import ListDataset
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_LABELS = ["LABEL_1", "LABEL_0"]
+
 
 class TransformersQueryClassifier(BaseQueryClassifier):
     """
-    A node to classify an incoming query into one of two categories using a (small) BERT transformer model.
+    A node to classify an incoming query into categories using a transformer model.
     Depending on the result, the query flows to a different branch in your pipeline and the further processing
-    can be customized. You can define this by connecting the further pipeline to either `output_1` or `output_2`
+    can be customized. You can define this by connecting the further pipeline to `output_1`, `output_2`, ..., `output_n`
     from this node.
+    This node also supports zero-shot-classification.
 
     Example:
     ```python

@@ -36,7 +41,7 @@ class TransformersQueryClassifier(BaseQueryClassifier):
 
     Models:
 
-    Pass your own `Transformer` binary classification model from file/huggingface or use one of the following
+    Pass your own `Transformer` classification/zero-shot-classification model from file/huggingface or use one of the following
     pretrained ones hosted on Huggingface:
     1) Keywords vs. Questions/Statements (Default)
     model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection"

@@ -58,56 +63,98 @@ class TransformersQueryClassifier(BaseQueryClassifier):
     def __init__(
         self,
         model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection",
+        model_version: Optional[str] = None,
+        tokenizer: Optional[str] = None,
         use_gpu: bool = True,
+        task: str = "text-classification",
+        labels: List[str] = DEFAULT_LABELS,
         batch_size: int = 16,
         progress_bar: bool = True,
     ):
         """
-        :param model_name_or_path: Transformer based fine tuned mini bert model for query classification
+        :param model_name_or_path: Directory of a saved model or the name of a public model, for example 'shahrukhx01/bert-mini-finetune-question-detection'.
+        See [Hugging Face models](https://huggingface.co/models) for a full list of available models.
+        :param model_version: The version of the model to use from the Hugging Face model hub. This can be a tag name, a branch name, or a commit hash.
+        :param tokenizer: The name of the tokenizer (usually the same as model).
         :param use_gpu: Whether to use GPU (if available).
-        :param batch_size: Batch size for inference.
+        :param task: Specifies the type of classification. Possible values: 'text-classification' or 'zero-shot-classification'.
+        :param labels: If the task is 'text-classification' and an ordered list of labels is provided, the first label corresponds to output_1,
+        the second label to output_2, and so on. The labels must match the model labels; only the order can differ.
+        If the task is 'zero-shot-classification', these are the candidate labels.
+        :param batch_size: The number of queries to be processed at a time.
         :param progress_bar: Whether to show a progress bar.
         """
         super().__init__()
-        self.devices, _ = initialize_device_settings(use_cuda=use_gpu)
+        devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
+        device = 0 if devices[0].type == "cuda" else -1
+
+        self.model = pipeline(
+            task=task, model=model_name_or_path, tokenizer=tokenizer, device=device, revision=model_version
+        )
+
+        self.labels = labels
+        if task == "text-classification":
+            labels_from_model = [label for label in self.model.model.config.id2label.values()]
+            if set(labels) != set(labels_from_model):
+                raise ValueError(
+                    f"For text-classification, the provided labels must match the model labels; only the order can differ.\n"
+                    f"Provided labels: {labels}\n"
+                    f"Model labels: {labels_from_model}"
+                )
+        if task not in ["text-classification", "zero-shot-classification"]:
+            raise ValueError(
+                f"Task not supported: {task}.\n"
+                f"Possible task values are: 'text-classification' or 'zero-shot-classification'"
+            )
+        self.task = task
         self.batch_size = batch_size
-        device = 0 if self.devices[0].type == "cuda" else -1
         self.progress_bar = progress_bar
 
-        model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)
-        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-
-        self.query_classification_pipeline = TextClassificationPipeline(model=model, tokenizer=tokenizer, device=device)
+    @classmethod
+    def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
+        labels = component_params.get("labels", DEFAULT_LABELS)
+        if labels is None or len(labels) == 0:
+            raise ValueError("The labels must be provided")
+        return len(labels)
 
-    def run(self, query):
-        is_question: bool = self.query_classification_pipeline(query)[0]["label"] == "LABEL_1"
-
-        if is_question:
-            return {}, "output_1"
-        else:
-            return {}, "output_2"
+    def _get_edge_number_from_label(self, label):
+        return self.labels.index(label) + 1
+
+    def run(self, query: str):  # type: ignore
+        if self.task == "zero-shot-classification":
+            prediction = self.model([query], candidate_labels=self.labels, truncation=True)
+            label = prediction[0]["labels"][0]
+        elif self.task == "text-classification":
+            prediction = self.model([query], truncation=True)
+            label = prediction[0]["label"]
+        return {}, f"output_{self._get_edge_number_from_label(label)}"
 
     def run_batch(self, queries: List[str], batch_size: Optional[int] = None):  # type: ignore
-        if batch_size is None:
-            batch_size = self.batch_size
-
-        split: Dict[str, Dict[str, List]] = {"output_1": {"queries": []}, "output_2": {"queries": []}}
-
         # HF pb hack https://discuss.huggingface.co/t/progress-bar-for-hf-pipelines/20498/2
         queries_dataset = ListDataset(queries)
+        if batch_size is None:
+            batch_size = self.batch_size
         all_predictions = []
-        for predictions in tqdm(
-            self.query_classification_pipeline(queries_dataset, batch_size=batch_size),
-            disable=not self.progress_bar,
-            desc="Classifying queries",
-        ):
-            all_predictions.extend(predictions)
-
-        for query, pred in zip(queries, all_predictions):
-            if pred["label"] == "LABEL_1":
-                split["output_1"]["queries"].append(query)
-            else:
-                split["output_2"]["queries"].append(query)
-
-        return split, "split"
+        if self.task == "zero-shot-classification":
+            for predictions in tqdm(
+                self.model(queries_dataset, candidate_labels=self.labels, truncation=True, batch_size=batch_size),
+                disable=not self.progress_bar,
+                desc="Classifying queries",
+            ):
+                all_predictions.extend([predictions])
+        elif self.task == "text-classification":
+            for predictions in tqdm(
+                self.model(queries_dataset, truncation=True, batch_size=batch_size),
+                disable=not self.progress_bar,
+                desc="Classifying queries",
+            ):
+                all_predictions.extend([predictions])
+        results = {f"output_{self._get_edge_number_from_label(label)}": {"queries": []} for label in self.labels}  # type: ignore
+        for query, prediction in zip(queries, all_predictions):
+            if self.task == "zero-shot-classification":
+                label = prediction["labels"][0]
+            elif self.task == "text-classification":
+                label = prediction["label"]
+            results[f"output_{self._get_edge_number_from_label(label)}"]["queries"].append(query)
+
+        return results, "split"

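The edge arithmetic added above is easy to check in isolation. A standalone sketch (not from the commit) restating `_calculate_outgoing_edges` and `_get_edge_number_from_label` as plain functions, so the mapping can be verified without loading a model:

```python
# Standalone re-statement of the routing helpers added in this hunk.
from typing import Any, Dict, List

DEFAULT_LABELS = ["LABEL_1", "LABEL_0"]

def calculate_outgoing_edges(component_params: Dict[str, Any]) -> int:
    # The number of outgoing edges equals the number of labels.
    labels = component_params.get("labels", DEFAULT_LABELS)
    if labels is None or len(labels) == 0:
        raise ValueError("The labels must be provided")
    return len(labels)

def get_edge_number_from_label(labels: List[str], label: str) -> int:
    # 1-based index: the first label routes to output_1.
    return labels.index(label) + 1

assert calculate_outgoing_edges({"labels": ["happy", "unhappy", "neutral"]}) == 3
assert get_edge_number_from_label(["happy", "unhappy", "neutral"], "neutral") == 3
```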
test/nodes/test_query_classifier.py (new file, 94 lines)
@@ -0,0 +1,94 @@
+import pytest
+
+from haystack.nodes.query_classifier.transformers import TransformersQueryClassifier
+
+
+@pytest.fixture
+def transformers_query_classifier():
+    return TransformersQueryClassifier(
+        model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
+        use_gpu=False,
+        task="text-classification",
+        labels=["LABEL_1", "LABEL_0"],
+    )
+
+
+@pytest.fixture
+def zero_shot_transformers_query_classifier():
+    return TransformersQueryClassifier(
+        model_name_or_path="typeform/distilbert-base-uncased-mnli",
+        use_gpu=False,
+        task="zero-shot-classification",
+        labels=["happy", "unhappy", "neutral"],
+    )
+
+
+def test_transformers_query_classifier(transformers_query_classifier):
+    output = transformers_query_classifier.run(query="morse code")
+    assert output == ({}, "output_2")
+
+    output = transformers_query_classifier.run(query="How old is John?")
+    assert output == ({}, "output_1")
+
+
+def test_transformers_query_classifier_batch(transformers_query_classifier):
+    queries = ["morse code", "How old is John?"]
+    output = transformers_query_classifier.run_batch(queries=queries)
+
+    assert output[0] == {"output_2": {"queries": ["morse code"]}, "output_1": {"queries": ["How old is John?"]}}
+
+
+def test_zero_shot_transformers_query_classifier(zero_shot_transformers_query_classifier):
+    output = zero_shot_transformers_query_classifier.run(query="What's the answer?")
+    assert output == ({}, "output_3")
+
+    output = zero_shot_transformers_query_classifier.run(query="Would you be so kind to tell me the answer?")
+    assert output == ({}, "output_1")
+
+    output = zero_shot_transformers_query_classifier.run(query="Can you give me the right answer for once??")
+    assert output == ({}, "output_2")
+
+
+def test_zero_shot_transformers_query_classifier_batch(zero_shot_transformers_query_classifier):
+    queries = [
+        "What's the answer?",
+        "Would you be so kind to tell me the answer?",
+        "Can you give me the right answer for once??",
+    ]
+
+    output = zero_shot_transformers_query_classifier.run_batch(queries=queries)
+
+    assert output[0] == {
+        "output_3": {"queries": ["What's the answer?"]},
+        "output_1": {"queries": ["Would you be so kind to tell me the answer?"]},
+        "output_2": {"queries": ["Can you give me the right answer for once??"]},
+    }
+
+
+def test_transformers_query_classifier_wrong_labels():
+    with pytest.raises(ValueError, match="For text-classification, the provided labels must match the model labels"):
+        query_classifier = TransformersQueryClassifier(
+            model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
+            use_gpu=False,
+            task="text-classification",
+            labels=["WRONG_LABEL_1", "WRONG_LABEL_2", "WRONG_LABEL_3"],
+        )
+
+
+def test_transformers_query_classifier_no_labels():
+    with pytest.raises(ValueError, match="The labels must be provided"):
+        query_classifier = TransformersQueryClassifier(
+            model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
+            use_gpu=False,
+            task="text-classification",
+            labels=None,
+        )
+
+
+def test_transformers_query_classifier_unsupported_task():
+    with pytest.raises(ValueError, match="Task not supported"):
+        query_classifier = TransformersQueryClassifier(
+            model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
+            use_gpu=False,
+            task="summarization",
+            labels=["LABEL_1", "LABEL_0"],
+        )

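The batch tests above also document the `run_batch` contract: it returns a dict keyed by outgoing edge plus the literal edge name "split". A hedged usage sketch outside a pipeline (uses the zero-shot model from the tests; downloads it on first run):

```python
from haystack.nodes.query_classifier.transformers import TransformersQueryClassifier

classifier = TransformersQueryClassifier(
    model_name_or_path="typeform/distilbert-base-uncased-mnli",
    use_gpu=False,
    task="zero-shot-classification",
    labels=["happy", "unhappy", "neutral"],
)

results, edge = classifier.run_batch(
    queries=["What's the answer?", "Would you be so kind to tell me the answer?"]
)
assert edge == "split"
# results groups queries per outgoing edge, e.g. {"output_1": {"queries": [...]}, ...}
for output_edge, payload in results.items():
    print(output_edge, payload["queries"])
```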