haystack/test/nodes/test_query_classifier.py
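
# Tests for TransformersQueryClassifier: classifying queries and routing them to a
# variable number of outgoing edges, using either a text-classification model or a
# zero-shot-classification model.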
import pytest
from haystack.nodes.query_classifier.transformers import TransformersQueryClassifier
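

# Fixture for the text-classification task: the model distinguishes natural-language
# questions (LABEL_1) from keyword/statement queries (LABEL_0).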
@pytest.fixture
def transformers_query_classifier():
    return TransformersQueryClassifier(
        model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
        use_gpu=False,
        task="text-classification",
        labels=["LABEL_1", "LABEL_0"],
    )
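

# Fixture for the zero-shot-classification task: arbitrary candidate labels can be
# used, here three sentiment-like categories.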
@pytest.fixture
def zero_shot_transformers_query_classifier():
    return TransformersQueryClassifier(
        model_name_or_path="typeform/distilbert-base-uncased-mnli",
        use_gpu=False,
        task="zero-shot-classification",
        labels=["happy", "unhappy", "neutral"],
    )
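

# Outgoing edges follow the order of the labels passed at init time: the first label
# routes to output_1, the second to output_2, and so on (see the assertions below).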
def test_transformers_query_classifier(transformers_query_classifier):
    # A keyword query is classified as LABEL_0 and routed to output_2
    output = transformers_query_classifier.run(query="morse code")
    assert output == ({}, "output_2")

    # A natural-language question is classified as LABEL_1 and routed to output_1
    output = transformers_query_classifier.run(query="How old is John?")
    assert output == ({}, "output_1")


def test_transformers_query_classifier_batch(transformers_query_classifier):
    queries = ["morse code", "How old is John?"]
    output = transformers_query_classifier.run_batch(queries=queries)
    assert output[0] == {"output_2": {"queries": ["morse code"]}, "output_1": {"queries": ["How old is John?"]}}


def test_zero_shot_transformers_query_classifier(zero_shot_transformers_query_classifier):
    # Queries are routed according to the predicted label's position in the labels list:
    # "neutral" -> output_3, "happy" -> output_1, "unhappy" -> output_2
    output = zero_shot_transformers_query_classifier.run(query="What's the answer?")
    assert output == ({}, "output_3")
    output = zero_shot_transformers_query_classifier.run(query="Would you be so kind to tell me the answer?")
    assert output == ({}, "output_1")
    output = zero_shot_transformers_query_classifier.run(query="Can you give me the right answer for once??")
    assert output == ({}, "output_2")


def test_zero_shot_transformers_query_classifier_batch(zero_shot_transformers_query_classifier):
    queries = [
        "What's the answer?",
        "Would you be so kind to tell me the answer?",
        "Can you give me the right answer for once??",
    ]
    output = zero_shot_transformers_query_classifier.run_batch(queries=queries)
    assert output[0] == {
        "output_3": {"queries": ["What's the answer?"]},
        "output_1": {"queries": ["Would you be so kind to tell me the answer?"]},
        "output_2": {"queries": ["Can you give me the right answer for once??"]},
    }
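

# Constructor validation: wrong or missing labels and unsupported tasks raise ValueError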
def test_transformers_query_classifier_wrong_labels():
    with pytest.raises(ValueError, match="For text-classification, the provided labels must match the model labels"):
        query_classifier = TransformersQueryClassifier(
            model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
            use_gpu=False,
            task="text-classification",
            labels=["WRONG_LABEL_1", "WRONG_LABEL_2", "WRONG_LABEL_3"],
        )


def test_transformers_query_classifier_no_labels():
    with pytest.raises(ValueError, match="The labels must be provided"):
        query_classifier = TransformersQueryClassifier(
            model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
            use_gpu=False,
            task="text-classification",
            labels=None,
        )


def test_transformers_query_classifier_unsupported_task():
    with pytest.raises(ValueError, match="Task not supported"):
        query_classifier = TransformersQueryClassifier(
            model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection",
            use_gpu=False,
            task="summarization",
            labels=["LABEL_1", "LABEL_0"],
        )