mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-28 23:48:53 +00:00
Add QueryClassifier incl. baseline models (#1099)
* restructure query classifier code and add s3 based pickles * make model and vectorizer optional in query classifier * update query classifier as per init style * add query classifiers sklearn/hf * update docstrings for query classifiers * add unit test for query classifier * add type patch for sklearn classifier * fix mypy type issue * revert to pure formatting * add query classifiers * resolve conflict * add output names for query classifier * revert output and update docstring queryclassifier * Update docstring for SklearnQueryClassifier * update transformer query classifier docstring * fix typo * change arg names in query classifier classes * add set_config(). rename attributes * fix set_config() Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
This commit is contained in:
parent
600636e77b
commit
545c625a37
@ -5,7 +5,11 @@ import traceback
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict
|
||||
from typing import List, Optional, Dict, Union, Any
|
||||
import pickle
|
||||
import urllib
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
|
||||
|
||||
import networkx as nx
|
||||
import yaml
|
||||
@ -594,6 +598,174 @@ class RootNode:
|
||||
return kwargs, "output_1"
|
||||
|
||||
|
||||
class SklearnQueryClassifier(BaseComponent):
|
||||
"""
|
||||
A node to classify an incoming query into one of two categories using a lightweight sklearn model. Depending on the result, the query flows to a different branch in your pipeline
|
||||
and the further processing can be customized. You can define this by connecting the further pipeline to either `output_1` or `output_2` from this node.
|
||||
|
||||
Example:
|
||||
```python
|
||||
|{
|
||||
|pipe = Pipeline()
|
||||
|pipe.add_node(component=SklearnQueryClassifier(), name="QueryClassifier", inputs=["Query"])
|
||||
|pipe.add_node(component=elastic_retriever, name="ElasticRetriever", inputs=["QueryClassifier.output_2"])
|
||||
|pipe.add_node(component=dpr_retriever, name="DPRRetriever", inputs=["QueryClassifier.output_1"])
|
||||
|
||||
|# Keyword queries will use the ElasticRetriever
|
||||
|pipe.run("kubernetes aws")
|
||||
|
||||
|# Semantic queries (questions, statements, sentences ...) will leverage the DPR retriever
|
||||
|pipe.run("How to manage kubernetes on aws")
|
||||
|
||||
```
|
||||
|
||||
Models:
|
||||
|
||||
Pass your own `Sklearn` binary classification model or use one of the following pretrained ones:
|
||||
1) Keywords vs. Questions/Statements (Default)
|
||||
query_classifier="https://ext-models-haystack.s3.eu-central-1.amazonaws.com/gradboost_query_classifier/model.pickle"
|
||||
query_vectorizer="https://ext-models-haystack.s3.eu-central-1.amazonaws.com/gradboost_query_classifier/vectorizer.pickle"
|
||||
output_1 => question/statement
|
||||
output_2 => keyword query
|
||||
[Readme](https://ext-models-haystack.s3.eu-central-1.amazonaws.com/gradboost_query_classifier/readme.txt)
|
||||
|
||||
|
||||
2) Questions vs. Statements
|
||||
`query_classifier`="https://ext-models-haystack.s3.eu-central-1.amazonaws.com/gradboost_query_classifier_statements/model.pickle"`
|
||||
`query_vectorizer`="https://ext-models-haystack.s3.eu-central-1.amazonaws.com/gradboost_query_classifier_statements/vectorizer.pickle"`
|
||||
output_1 => question
|
||||
output_2 => statement
|
||||
[Readme](https://ext-models-haystack.s3.eu-central-1.amazonaws.com/gradboost_query_classifier_statements/readme.txt)
|
||||
|
||||
See also the [tutorial](https://haystack.deepset.ai/docs/latest/tutorial11md) on pipelines.
|
||||
|
||||
"""
|
||||
|
||||
outgoing_edges = 2
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path: Union[
|
||||
str, Any
|
||||
] = "https://ext-models-haystack.s3.eu-central-1.amazonaws.com/gradboost_query_classifier/model.pickle",
|
||||
vectorizer_name_or_path: Union[
|
||||
str, Any
|
||||
] = "https://ext-models-haystack.s3.eu-central-1.amazonaws.com/gradboost_query_classifier/vectorizer.pickle"
|
||||
):
|
||||
"""
|
||||
:param model_name_or_path: Gradient boosting based binary classifier to classify between keyword vs statement/question
|
||||
queries or statement vs question queries.
|
||||
:param vectorizer_name_or_path: A ngram based Tfidf vectorizer for extracting features from query.
|
||||
"""
|
||||
if (
|
||||
(not isinstance(model_name_or_path, Path))
|
||||
and (not isinstance(model_name_or_path, str))
|
||||
) or (
|
||||
(not isinstance(vectorizer_name_or_path, Path))
|
||||
and (not isinstance(vectorizer_name_or_path, str))
|
||||
):
|
||||
raise TypeError(
|
||||
"query_classifier and query_classifier must either be of type Path or str"
|
||||
)
|
||||
|
||||
# save init parameters to enable export of component config as YAML
|
||||
self.set_config(model_name_or_path=model_name_or_path, vectorizer_name_or_path=vectorizer_name_or_path)
|
||||
|
||||
if isinstance(model_name_or_path, Path):
|
||||
file_url = urllib.request.pathname2url(r"{}".format(model_name_or_path))
|
||||
model_name_or_path = f"file:{file_url}"
|
||||
|
||||
if isinstance(vectorizer_name_or_path, Path):
|
||||
file_url = urllib.request.pathname2url(r"{}".format(vectorizer_name_or_path))
|
||||
vectorizer_name_or_path = f"file:{file_url}"
|
||||
|
||||
self.model = pickle.load(urllib.request.urlopen(model_name_or_path))
|
||||
self.vectorizer = pickle.load(urllib.request.urlopen(vectorizer_name_or_path))
|
||||
|
||||
|
||||
def run(self, **kwargs):
|
||||
query_vector = self.vectorizer.transform([kwargs["query"]])
|
||||
|
||||
is_question: bool = self.model.predict(query_vector)[0]
|
||||
if is_question:
|
||||
return (kwargs, "output_1")
|
||||
else:
|
||||
return (kwargs, "output_2")
|
||||
|
||||
|
||||
class TransformersQueryClassifier(BaseComponent):
|
||||
"""
|
||||
A node to classify an incoming query into one of two categories using a (small) BERT transformer model. Depending on the result, the query flows to a different branch in your pipeline
|
||||
and the further processing can be customized. You can define this by connecting the further pipeline to either `output_1` or `output_2` from this node.
|
||||
|
||||
Example:
|
||||
```python
|
||||
|{
|
||||
|pipe = Pipeline()
|
||||
|pipe.add_node(component=TransformersQueryClassifier(), name="QueryClassifier", inputs=["Query"])
|
||||
|pipe.add_node(component=elastic_retriever, name="ElasticRetriever", inputs=["QueryClassifier.output_2"])
|
||||
|pipe.add_node(component=dpr_retriever, name="DPRRetriever", inputs=["QueryClassifier.output_1"])
|
||||
|
||||
|# Keyword queries will use the ElasticRetriever
|
||||
|pipe.run("kubernetes aws")
|
||||
|
||||
|# Semantic queries (questions, statements, sentences ...) will leverage the DPR retriever
|
||||
|pipe.run("How to manage kubernetes on aws")
|
||||
|
||||
```
|
||||
|
||||
Models:
|
||||
|
||||
Pass your own `Transformer` binary classification model from file/huggingface or use one of the following pretrained ones hosted on Huggingface:
|
||||
1) Keywords vs. Questions/Statements (Default)
|
||||
model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection"
|
||||
output_1 => question/statement
|
||||
output_2 => keyword query
|
||||
[Readme](https://ext-models-haystack.s3.eu-central-1.amazonaws.com/gradboost_query_classifier/readme.txt)
|
||||
|
||||
|
||||
2) Questions vs. Statements
|
||||
`model_name_or_path`="shahrukhx01/question-vs-statement-classifier"
|
||||
output_1 => question
|
||||
output_2 => statement
|
||||
[Readme](https://ext-models-haystack.s3.eu-central-1.amazonaws.com/gradboost_query_classifier_statements/readme.txt)
|
||||
|
||||
See also the [tutorial](https://haystack.deepset.ai/docs/latest/tutorial11md) on pipelines.
|
||||
"""
|
||||
|
||||
outgoing_edges = 2
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path: Union[
|
||||
Path, str
|
||||
] = "shahrukhx01/bert-mini-finetune-question-detection"
|
||||
):
|
||||
"""
|
||||
:param model_name_or_path: Transformer based fine tuned mini bert model for query classification
|
||||
"""
|
||||
# save init parameters to enable export of component config as YAML
|
||||
self.set_config(model_name_or_path=model_name_or_path)
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
|
||||
self.query_classification_pipeline = TextClassificationPipeline(
|
||||
model=model, tokenizer=tokenizer
|
||||
)
|
||||
|
||||
def run(self, **kwargs):
|
||||
|
||||
is_question: bool = (
|
||||
self.query_classification_pipeline(kwargs["query"])[0]["label"] == "LABEL_1"
|
||||
)
|
||||
|
||||
if is_question:
|
||||
return (kwargs, "output_1")
|
||||
else:
|
||||
return (kwargs, "output_2")
|
||||
|
||||
|
||||
class JoinDocuments(BaseComponent):
|
||||
"""
|
||||
A node to join documents outputted by multiple retriever nodes.
|
||||
@ -655,6 +827,7 @@ class JoinDocuments(BaseComponent):
|
||||
raise Exception(f"Invalid join_mode: {self.join_mode}")
|
||||
|
||||
documents = sorted(document_map.values(), key=lambda d: d.score, reverse=True)
|
||||
|
||||
if self.top_k_join:
|
||||
documents = documents[: self.top_k_join]
|
||||
output = {"query": inputs[0]["query"], "documents": documents, "labels": inputs[0].get("labels", None)}
|
||||
|
||||
@ -3,8 +3,17 @@ from pathlib import Path
|
||||
import pytest
|
||||
|
||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||
from haystack.pipeline import TranslationWrapperPipeline, JoinDocuments, ExtractiveQAPipeline, Pipeline, FAQPipeline, \
|
||||
DocumentSearchPipeline, RootNode
|
||||
from haystack.pipeline import (
|
||||
TranslationWrapperPipeline,
|
||||
JoinDocuments,
|
||||
ExtractiveQAPipeline,
|
||||
Pipeline,
|
||||
FAQPipeline,
|
||||
DocumentSearchPipeline,
|
||||
RootNode,
|
||||
SklearnQueryClassifier,
|
||||
TransformersQueryClassifier,
|
||||
)
|
||||
from haystack.retriever.dense import DensePassageRetriever
|
||||
from haystack.retriever.sparse import ElasticsearchRetriever
|
||||
|
||||
@ -12,24 +21,36 @@ from haystack.retriever.sparse import ElasticsearchRetriever
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_load_and_save_yaml(document_store_with_docs, tmp_path):
|
||||
# test correct load of indexing pipeline from yaml
|
||||
pipeline = Pipeline.load_from_yaml(Path("samples/pipeline/test_pipeline.yaml"), pipeline_name="indexing_pipeline")
|
||||
pipeline.run(file_path=Path("samples/pdf/sample_pdf_1.pdf"), top_k_retriever=10, top_k_reader=3)
|
||||
pipeline = Pipeline.load_from_yaml(
|
||||
Path("samples/pipeline/test_pipeline.yaml"), pipeline_name="indexing_pipeline"
|
||||
)
|
||||
pipeline.run(
|
||||
file_path=Path("samples/pdf/sample_pdf_1.pdf"),
|
||||
top_k_retriever=10,
|
||||
top_k_reader=3,
|
||||
)
|
||||
|
||||
# test correct load of query pipeline from yaml
|
||||
pipeline = Pipeline.load_from_yaml(Path("samples/pipeline/test_pipeline.yaml"), pipeline_name="query_pipeline")
|
||||
prediction = pipeline.run(query="Who made the PDF specification?", top_k_retriever=10, top_k_reader=3)
|
||||
pipeline = Pipeline.load_from_yaml(
|
||||
Path("samples/pipeline/test_pipeline.yaml"), pipeline_name="query_pipeline"
|
||||
)
|
||||
prediction = pipeline.run(
|
||||
query="Who made the PDF specification?", top_k_retriever=10, top_k_reader=3
|
||||
)
|
||||
assert prediction["query"] == "Who made the PDF specification?"
|
||||
assert prediction["answers"][0]["answer"] == "Adobe Systems"
|
||||
|
||||
# test invalid pipeline name
|
||||
with pytest.raises(Exception):
|
||||
Pipeline.load_from_yaml(path=Path("samples/pipeline/test_pipeline.yaml"), pipeline_name="invalid")
|
||||
Pipeline.load_from_yaml(
|
||||
path=Path("samples/pipeline/test_pipeline.yaml"), pipeline_name="invalid"
|
||||
)
|
||||
|
||||
# test config export
|
||||
pipeline.save_to_yaml(tmp_path / "test.yaml")
|
||||
with open(tmp_path/"test.yaml", "r", encoding='utf-8') as stream:
|
||||
with open(tmp_path / "test.yaml", "r", encoding="utf-8") as stream:
|
||||
saved_yaml = stream.read()
|
||||
expected_yaml = '''
|
||||
expected_yaml = """
|
||||
components:
|
||||
- name: ESRetriever
|
||||
params:
|
||||
@ -56,31 +77,43 @@ def test_load_and_save_yaml(document_store_with_docs, tmp_path):
|
||||
name: Reader
|
||||
type: Query
|
||||
version: '0.8'
|
||||
'''
|
||||
assert saved_yaml.replace(" ", "").replace("\n", "") == expected_yaml.replace(" ", "").replace("\n", "")
|
||||
"""
|
||||
assert saved_yaml.replace(" ", "").replace("\n", "") == expected_yaml.replace(
|
||||
" ", ""
|
||||
).replace("\n", "")
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize(
|
||||
"retriever_with_docs, document_store_with_docs", [("elasticsearch", "elasticsearch")], indirect=True
|
||||
"retriever_with_docs, document_store_with_docs",
|
||||
[("elasticsearch", "elasticsearch")],
|
||||
indirect=True,
|
||||
)
|
||||
def test_graph_creation(reader, retriever_with_docs, document_store_with_docs):
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(name="ES", component=retriever_with_docs, inputs=["Query"])
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["ES.output_2"])
|
||||
pipeline.add_node(
|
||||
name="Reader", component=retriever_with_docs, inputs=["ES.output_2"]
|
||||
)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["ES.wrong_edge_label"])
|
||||
pipeline.add_node(
|
||||
name="Reader", component=retriever_with_docs, inputs=["ES.wrong_edge_label"]
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["InvalidNode"])
|
||||
pipeline.add_node(
|
||||
name="Reader", component=retriever_with_docs, inputs=["InvalidNode"]
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(name="ES", component=retriever_with_docs, inputs=["InvalidNode"])
|
||||
pipeline.add_node(
|
||||
name="ES", component=retriever_with_docs, inputs=["InvalidNode"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@ -88,14 +121,18 @@ def test_graph_creation(reader, retriever_with_docs, document_store_with_docs):
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
|
||||
def test_extractive_qa_answers(reader, retriever_with_docs):
|
||||
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
|
||||
prediction = pipeline.run(query="Who lives in Berlin?", top_k_retriever=10, top_k_reader=3)
|
||||
prediction = pipeline.run(
|
||||
query="Who lives in Berlin?", top_k_retriever=10, top_k_reader=3
|
||||
)
|
||||
assert prediction is not None
|
||||
assert prediction["query"] == "Who lives in Berlin?"
|
||||
assert prediction["answers"][0]["answer"] == "Carla"
|
||||
assert prediction["answers"][0]["probability"] <= 1
|
||||
assert prediction["answers"][0]["probability"] >= 0
|
||||
assert prediction["answers"][0]["meta"]["meta_field"] == "test1"
|
||||
assert prediction["answers"][0]["context"] == "My name is Carla and I live in Berlin"
|
||||
assert (
|
||||
prediction["answers"][0]["context"] == "My name is Carla and I live in Berlin"
|
||||
)
|
||||
|
||||
assert len(prediction["answers"]) == 3
|
||||
|
||||
@ -104,13 +141,18 @@ def test_extractive_qa_answers(reader, retriever_with_docs):
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
|
||||
def test_extractive_qa_offsets(reader, retriever_with_docs):
|
||||
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
|
||||
prediction = pipeline.run(query="Who lives in Berlin?", top_k_retriever=10, top_k_reader=5)
|
||||
prediction = pipeline.run(
|
||||
query="Who lives in Berlin?", top_k_retriever=10, top_k_reader=5
|
||||
)
|
||||
|
||||
assert prediction["answers"][0]["offset_start"] == 11
|
||||
assert prediction["answers"][0]["offset_end"] == 16
|
||||
start = prediction["answers"][0]["offset_start"]
|
||||
end = prediction["answers"][0]["offset_end"]
|
||||
assert prediction["answers"][0]["context"][start:end] == prediction["answers"][0]["answer"]
|
||||
assert (
|
||||
prediction["answers"][0]["context"][start:end]
|
||||
== prediction["answers"][0]["answer"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@ -127,16 +169,36 @@ def test_extractive_qa_answers_single_result(reader, retriever_with_docs):
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize(
|
||||
"retriever,document_store",
|
||||
[("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus"), ("embedding", "elasticsearch")],
|
||||
[
|
||||
("embedding", "memory"),
|
||||
("embedding", "faiss"),
|
||||
("embedding", "milvus"),
|
||||
("embedding", "elasticsearch"),
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_faq_pipeline(retriever, document_store):
|
||||
documents = [
|
||||
{"text": "How to test module-1?", 'meta': {"source": "wiki1", "answer": "Using tests for module-1"}},
|
||||
{"text": "How to test module-2?", 'meta': {"source": "wiki2", "answer": "Using tests for module-2"}},
|
||||
{"text": "How to test module-3?", 'meta': {"source": "wiki3", "answer": "Using tests for module-3"}},
|
||||
{"text": "How to test module-4?", 'meta': {"source": "wiki4", "answer": "Using tests for module-4"}},
|
||||
{"text": "How to test module-5?", 'meta': {"source": "wiki5", "answer": "Using tests for module-5"}},
|
||||
{
|
||||
"text": "How to test module-1?",
|
||||
"meta": {"source": "wiki1", "answer": "Using tests for module-1"},
|
||||
},
|
||||
{
|
||||
"text": "How to test module-2?",
|
||||
"meta": {"source": "wiki2", "answer": "Using tests for module-2"},
|
||||
},
|
||||
{
|
||||
"text": "How to test module-3?",
|
||||
"meta": {"source": "wiki3", "answer": "Using tests for module-3"},
|
||||
},
|
||||
{
|
||||
"text": "How to test module-4?",
|
||||
"meta": {"source": "wiki4", "answer": "Using tests for module-4"},
|
||||
},
|
||||
{
|
||||
"text": "How to test module-5?",
|
||||
"meta": {"source": "wiki5", "answer": "Using tests for module-5"},
|
||||
},
|
||||
]
|
||||
|
||||
document_store.write_documents(documents)
|
||||
@ -150,23 +212,30 @@ def test_faq_pipeline(retriever, document_store):
|
||||
assert output["answers"][0]["answer"].startswith("Using tests")
|
||||
|
||||
if isinstance(document_store, ElasticsearchDocumentStore):
|
||||
output = pipeline.run(query="How to test this?", filters={"source": ["wiki2"]}, top_k_retriever=5)
|
||||
output = pipeline.run(
|
||||
query="How to test this?", filters={"source": ["wiki2"]}, top_k_retriever=5
|
||||
)
|
||||
assert len(output["answers"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize(
|
||||
"retriever,document_store",
|
||||
[("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus"), ("embedding", "elasticsearch")],
|
||||
[
|
||||
("embedding", "memory"),
|
||||
("embedding", "faiss"),
|
||||
("embedding", "milvus"),
|
||||
("embedding", "elasticsearch"),
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_document_search_pipeline(retriever, document_store):
|
||||
documents = [
|
||||
{"text": "Sample text for document-1", 'meta': {"source": "wiki1"}},
|
||||
{"text": "Sample text for document-2", 'meta': {"source": "wiki2"}},
|
||||
{"text": "Sample text for document-3", 'meta': {"source": "wiki3"}},
|
||||
{"text": "Sample text for document-4", 'meta': {"source": "wiki4"}},
|
||||
{"text": "Sample text for document-5", 'meta': {"source": "wiki5"}},
|
||||
{"text": "Sample text for document-1", "meta": {"source": "wiki1"}},
|
||||
{"text": "Sample text for document-2", "meta": {"source": "wiki2"}},
|
||||
{"text": "Sample text for document-3", "meta": {"source": "wiki3"}},
|
||||
{"text": "Sample text for document-4", "meta": {"source": "wiki4"}},
|
||||
{"text": "Sample text for document-5", "meta": {"source": "wiki5"}},
|
||||
]
|
||||
|
||||
document_store.write_documents(documents)
|
||||
@ -174,32 +243,40 @@ def test_document_search_pipeline(retriever, document_store):
|
||||
|
||||
pipeline = DocumentSearchPipeline(retriever=retriever)
|
||||
output = pipeline.run(query="How to test this?", top_k_retriever=4)
|
||||
assert len(output.get('documents', [])) == 4
|
||||
assert len(output.get("documents", [])) == 4
|
||||
|
||||
if isinstance(document_store, ElasticsearchDocumentStore):
|
||||
output = pipeline.run(query="How to test this?", filters={"source": ["wiki2"]}, top_k_retriever=5)
|
||||
output = pipeline.run(
|
||||
query="How to test this?", filters={"source": ["wiki2"]}, top_k_retriever=5
|
||||
)
|
||||
assert len(output["documents"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
|
||||
def test_extractive_qa_answers_with_translator(reader, retriever_with_docs, en_to_de_translator, de_to_en_translator):
|
||||
def test_extractive_qa_answers_with_translator(
|
||||
reader, retriever_with_docs, en_to_de_translator, de_to_en_translator
|
||||
):
|
||||
base_pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
|
||||
pipeline = TranslationWrapperPipeline(
|
||||
input_translator=de_to_en_translator,
|
||||
output_translator=en_to_de_translator,
|
||||
pipeline=base_pipeline
|
||||
pipeline=base_pipeline,
|
||||
)
|
||||
|
||||
prediction = pipeline.run(query="Wer lebt in Berlin?", top_k_retriever=10, top_k_reader=3)
|
||||
prediction = pipeline.run(
|
||||
query="Wer lebt in Berlin?", top_k_retriever=10, top_k_reader=3
|
||||
)
|
||||
assert prediction is not None
|
||||
assert prediction["query"] == "Wer lebt in Berlin?"
|
||||
assert "Carla" in prediction["answers"][0]["answer"]
|
||||
assert prediction["answers"][0]["probability"] <= 1
|
||||
assert prediction["answers"][0]["probability"] >= 0
|
||||
assert prediction["answers"][0]["meta"]["meta_field"] == "test1"
|
||||
assert prediction["answers"][0]["context"] == "My name is Carla and I live in Berlin"
|
||||
assert (
|
||||
prediction["answers"][0]["context"] == "My name is Carla and I live in Berlin"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
@ -283,7 +360,9 @@ def test_parallel_paths_in_pipeline_graph():
|
||||
|
||||
class JoinNode(RootNode):
|
||||
def run(self, **kwargs):
|
||||
kwargs["output"] = kwargs["inputs"][0]["output"] + kwargs["inputs"][1]["output"]
|
||||
kwargs["output"] = (
|
||||
kwargs["inputs"][0]["output"] + kwargs["inputs"][1]["output"]
|
||||
)
|
||||
return kwargs, "output_1"
|
||||
|
||||
pipeline = Pipeline()
|
||||
@ -309,18 +388,21 @@ def test_parallel_paths_in_pipeline_graph():
|
||||
def test_parallel_paths_in_pipeline_graph_with_branching():
|
||||
class AWithOutput1(RootNode):
|
||||
outgoing_edges = 2
|
||||
|
||||
def run(self, **kwargs):
|
||||
kwargs["output"] = "A"
|
||||
return kwargs, "output_1"
|
||||
|
||||
class AWithOutput2(RootNode):
|
||||
outgoing_edges = 2
|
||||
|
||||
def run(self, **kwargs):
|
||||
kwargs["output"] = "A"
|
||||
return kwargs, "output_2"
|
||||
|
||||
class AWithOutputAll(RootNode):
|
||||
outgoing_edges = 2
|
||||
|
||||
def run(self, **kwargs):
|
||||
kwargs["output"] = "A"
|
||||
return kwargs, "output_all"
|
||||
@ -350,7 +432,7 @@ def test_parallel_paths_in_pipeline_graph_with_branching():
|
||||
if kwargs.get("inputs"):
|
||||
kwargs["output"] = ""
|
||||
for input_dict in kwargs["inputs"]:
|
||||
kwargs["output"] += (input_dict["output"])
|
||||
kwargs["output"] += input_dict["output"]
|
||||
return kwargs, "output_1"
|
||||
|
||||
pipeline = Pipeline()
|
||||
@ -384,4 +466,61 @@ def test_parallel_paths_in_pipeline_graph_with_branching():
|
||||
assert output["output"] == "ACABEABD"
|
||||
|
||||
|
||||
def test_query_keyword_statement_classifier():
|
||||
class KeywordOutput(RootNode):
|
||||
outgoing_edges = 2
|
||||
|
||||
def run(self, **kwargs):
|
||||
kwargs["output"] = "keyword"
|
||||
return kwargs, "output_1"
|
||||
|
||||
class QuestionOutput(RootNode):
|
||||
outgoing_edges = 2
|
||||
|
||||
def run(self, **kwargs):
|
||||
kwargs["output"] = "question"
|
||||
return kwargs, "output_2"
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(
|
||||
name="SkQueryKeywordQuestionClassifier",
|
||||
component=SklearnQueryClassifier(),
|
||||
inputs=["Query"],
|
||||
)
|
||||
pipeline.add_node(
|
||||
name="KeywordNode",
|
||||
component=KeywordOutput(),
|
||||
inputs=["SkQueryKeywordQuestionClassifier.output_2"],
|
||||
)
|
||||
pipeline.add_node(
|
||||
name="QuestionNode",
|
||||
component=QuestionOutput(),
|
||||
inputs=["SkQueryKeywordQuestionClassifier.output_1"],
|
||||
)
|
||||
output = pipeline.run(query="morse code")
|
||||
assert output["output"] == "keyword"
|
||||
|
||||
output = pipeline.run(query="How old is John?")
|
||||
assert output["output"] == "question"
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(
|
||||
name="TfQueryKeywordQuestionClassifier",
|
||||
component=TransformersQueryClassifier(),
|
||||
inputs=["Query"],
|
||||
)
|
||||
pipeline.add_node(
|
||||
name="KeywordNode",
|
||||
component=KeywordOutput(),
|
||||
inputs=["TfQueryKeywordQuestionClassifier.output_2"],
|
||||
)
|
||||
pipeline.add_node(
|
||||
name="QuestionNode",
|
||||
component=QuestionOutput(),
|
||||
inputs=["TfQueryKeywordQuestionClassifier.output_1"],
|
||||
)
|
||||
output = pipeline.run(query="morse code")
|
||||
assert output["output"] == "keyword"
|
||||
|
||||
output = pipeline.run(query="How old is John?")
|
||||
assert output["output"] == "question"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user