Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-07-26 02:10:41 +00:00)

* extend query classifier in one commit
* variable number of outgoing edges
* improve tests
* fix unused import
* lightweight approach
* fix _calculate_outgoing_edges
* remove duplicate label validation
* remove print
95 lines · 3.5 KiB · Python
import pytest
|
|
from haystack.nodes.query_classifier.transformers import TransformersQueryClassifier
|
|
|
|
|
|
@pytest.fixture
def transformers_query_classifier():
    """Build a text-classification query classifier for the tests in this module.

    Loads the ``shahrukhx01/bert-mini-finetune-question-detection`` checkpoint with
    GPU use disabled. The tests expect question-like queries on ``output_1`` and
    other queries on ``output_2``. Presumably the outgoing edge follows the position
    of the predicted label in ``labels`` (``LABEL_1`` first); confirm against the
    TransformersQueryClassifier implementation.
    """
    settings = {
        "model_name_or_path": "shahrukhx01/bert-mini-finetune-question-detection",
        "use_gpu": False,
        "task": "text-classification",
        "labels": ["LABEL_1", "LABEL_0"],
    }
    return TransformersQueryClassifier(**settings)
|
|
|
|
|
|
@pytest.fixture
def zero_shot_transformers_query_classifier():
    """Build a zero-shot query classifier with three candidate labels.

    Loads ``typeform/distilbert-base-uncased-mnli`` with GPU use disabled. The
    candidate labels are given in the order happy, unhappy, neutral. The zero-shot
    tests in this module are consistent with label ``i`` mapping to edge
    ``output_{i+1}``, but that mapping is presumed rather than shown here; verify it
    against the classifier implementation.
    """
    candidate_labels = ["happy", "unhappy", "neutral"]
    return TransformersQueryClassifier(
        model_name_or_path="typeform/distilbert-base-uncased-mnli",
        labels=candidate_labels,
        task="zero-shot-classification",
        use_gpu=False,
    )
|
|
|
|
|
|
def test_transformers_query_classifier(transformers_query_classifier):
    """Single-query ``run`` returns an empty output dict together with the chosen edge name."""
    expected_routes = (
        ("morse code", "output_2"),  # keyword-style query, not phrased as a question
        ("How old is John?", "output_1"),  # phrased as a question
    )
    for query, edge in expected_routes:
        result = transformers_query_classifier.run(query=query)
        assert result == ({}, edge)
|
|
|
|
|
|
def test_transformers_query_classifier_batch(transformers_query_classifier):
    """``run_batch`` groups queries by outgoing edge in the first element of its result."""
    keyword_query = "morse code"
    question_query = "How old is John?"

    routed = transformers_query_classifier.run_batch(queries=[keyword_query, question_query])[0]

    assert routed == {
        "output_2": {"queries": [keyword_query]},
        "output_1": {"queries": [question_query]},
    }
|
|
|
|
|
|
def test_zero_shot_transformers_query_classifier(zero_shot_transformers_query_classifier):
    """Each single query is routed to the edge expected for the label the zero-shot model picks."""
    cases = [
        ("What's the answer?", "output_3"),
        ("Would you be so kind to tell me the answer?", "output_1"),
        ("Can you give me the right answer for once??", "output_2"),
    ]
    for query, expected_edge in cases:
        assert zero_shot_transformers_query_classifier.run(query=query) == ({}, expected_edge)
|
|
|
|
|
|
def test_zero_shot_transformers_query_classifier_batch(zero_shot_transformers_query_classifier):
    """``run_batch`` buckets every query under the edge expected for its predicted label."""
    edge_for_query = {
        "What's the answer?": "output_3",
        "Would you be so kind to tell me the answer?": "output_1",
        "Can you give me the right answer for once??": "output_2",
    }
    # dict preserves insertion order, so the queries are sent in the order listed above.
    queries = list(edge_for_query)

    output = zero_shot_transformers_query_classifier.run_batch(queries=queries)

    # Each query maps to a distinct edge, so every bucket holds exactly one query.
    expected = {edge: {"queries": [query]} for query, edge in edge_for_query.items()}
    assert output[0] == expected
|
|
|
|
|
|
def test_transformers_query_classifier_wrong_labels():
    """Text-classification with labels that don't match the model's labels must raise ValueError.

    The labels passed here are deliberately made up. The constructor is expected to
    reject them, so no instance is ever produced.
    """
    with pytest.raises(ValueError, match="For text-classification, the provided labels must match the model labels"):
        # Not bound to a name: construction itself is what must raise.
        TransformersQueryClassifier(
            model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
            use_gpu=False,
            task="text-classification",
            labels=["WRONG_LABEL_1", "WRONG_LABEL_2", "WRONG_LABEL_3"],
        )
|
|
|
|
|
|
def test_transformers_query_classifier_no_labels():
    """Passing ``labels=None`` to the constructor must raise ValueError."""
    with pytest.raises(ValueError, match="The labels must be provided"):
        # Not bound to a name: construction itself is what must raise.
        TransformersQueryClassifier(
            model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
            use_gpu=False,
            task="text-classification",
            labels=None,
        )
|
|
|
|
|
|
def test_transformers_query_classifier_unsupported_task():
    """A task other than the supported classification tasks (here ``summarization``) must raise ValueError."""
    with pytest.raises(ValueError, match="Task not supported"):
        # Not bound to a name: construction itself is what must raise.
        TransformersQueryClassifier(
            model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
            use_gpu=False,
            task="summarization",
            labels=["LABEL_1", "LABEL_0"],
        )
|