mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-06 03:57:19 +00:00
Adding Translator (standalone component & wrapper for pipelines) (#782)
* Adding translator with many generic input parameter support
* Making dict_key as generic
* Fixing mypy issue
* Adding pipeline and using opus models
* Add latest docstring and tutorial changes
* Adding test cases for end-to-end translation for generator, summerizer etc
* raise error join and merge nodes
* Fix test failure
* add docstrings. add usage documentation. rm skip_special_tokens param
* Add latest docstring and tutorial changes
* fix code snippets in md
* Adding few extra configuration parameters and fixing tests
* Fixingmypy issue and updating usage document
* fix for mypy issue in pipeline.py
* reverting renaming of pytest_collection_modifyitems method
* Addressing review comments
* setting skip_special_tokens to True
* removing model_max_length argument as None type is not supported to many models
* Removing padding parameter. Better to leave it as default otherwise it cause tensor size miss match error. If this option required by used then it can be added later.

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
parent
4059805d89
commit
5bd94ac5f7
@ -5,7 +5,7 @@
|
||||
## Pipeline Objects
|
||||
|
||||
```python
|
||||
class Pipeline()
|
||||
class Pipeline(ABC)
|
||||
```
|
||||
|
||||
Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components.
|
||||
@ -131,7 +131,7 @@ variable 'MYDOCSTORE_PARAMS_INDEX=documents-2021' can be set. Note that an
|
||||
## BaseStandardPipeline Objects
|
||||
|
||||
```python
|
||||
class BaseStandardPipeline()
|
||||
class BaseStandardPipeline(ABC)
|
||||
```
|
||||
|
||||
<a name="pipeline.BaseStandardPipeline.add_node"></a>
|
||||
@ -316,6 +316,32 @@ Initialize a Pipeline for finding similar FAQs using semantic document search.
|
||||
|
||||
- `retriever`: Retriever instance
|
||||
|
||||
<a name="pipeline.TranslationWrapperPipeline"></a>
|
||||
## TranslationWrapperPipeline Objects
|
||||
|
||||
```python
|
||||
class TranslationWrapperPipeline(BaseStandardPipeline)
|
||||
```
|
||||
|
||||
Takes an existing search pipeline and adds one "input translation node" after the Query and one
|
||||
"output translation" node just before returning the results
|
||||
|
||||
<a name="pipeline.TranslationWrapperPipeline.__init__"></a>
|
||||
#### \_\_init\_\_
|
||||
|
||||
```python
|
||||
| __init__(input_translator: BaseTranslator, output_translator: BaseTranslator, pipeline: BaseStandardPipeline)
|
||||
```
|
||||
|
||||
Wrap a given `pipeline` with the `input_translator` and `output_translator`.
|
||||
|
||||
**Arguments**:
|
||||
|
||||
- `input_translator`: A Translator node that shall translate the input query from language A to B
|
||||
- `output_translator`: A Translator node that shall translate the pipeline results from language B to A
|
||||
- `pipeline`: The pipeline object (e.g. ExtractiveQAPipeline) you want to "wrap".
|
||||
Note that pipelines with split or merge nodes are currently not supported.
|
||||
|
||||
<a name="pipeline.JoinDocuments"></a>
|
||||
## JoinDocuments Objects
|
||||
|
||||
|
||||
58
docs/_src/usage/usage/translator.md
Normal file
@ -0,0 +1,58 @@
<!---
title: "Translator"
metaTitle: "Translator"
metaDescription: ""
slug: "/docs/translator"
date: "2021-02-10"
id: "translatormd"
--->

# Translator

Texts come in different languages, and search is no exception. There are plenty of options to deal with this.
One of them is to translate the incoming query, the documents, or the search results.

Let's imagine you have an English corpus of technical docs, but the mother tongue of many of your users is French.
You can use a Translator node in your pipeline to
1. Translate the incoming query from French to English
2. Search in your English corpus for the right document / answer
3. Translate the results back from English to French

<div class="recommendation">

**Example (Stand-alone Translator)**

You can use the Translator component directly to translate your query or document(s):
```python
from haystack import Document
from haystack.translator import TransformersTranslator

DOCS = [
        Document(
            text="""Heinz von Foerster was an Austrian American scientist
                    combining physics and philosophy, and widely attributed
                    as the originator of Second-order cybernetics."""
        )
]
# English -> French model
translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-en-fr")
res = translator.translate(documents=DOCS, query=None)
```

**Example (Wrapping another Pipeline)**

You can also wrap one of your existing pipelines and "add" the translation nodes at the beginning and at the end of your pipeline.
For example, let's translate the incoming query from French to English, then do our document retrieval, and then translate the results back from English to French:

```python
from haystack.pipeline import TranslationWrapperPipeline, DocumentSearchPipeline
from haystack.translator import TransformersTranslator

# `my_dpr_retriever` is a retriever you have already initialized
pipeline = DocumentSearchPipeline(retriever=my_dpr_retriever)

in_translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-fr-en")
out_translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-en-fr")

pipeline_with_translation = TranslationWrapperPipeline(input_translator=in_translator,
                                                       output_translator=out_translator,
                                                       pipeline=pipeline)
```

</div>
|
||||
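The three-step flow described in the usage document (translate the query, search, translate the results back) can be sketched without any model downloads. The block below is only an illustration under stated assumptions: the lookup tables, the toy corpus, and the helper names `translate`, `search`, and `translated_search` are hypothetical stand-ins and are not part of Haystack.

```python
from typing import Dict, List

# Hypothetical stand-ins: tiny lookup tables instead of real translation models.
FR_TO_EN: Dict[str, str] = {"capitale de la france": "capital of france"}
EN_TO_FR: Dict[str, str] = {
    "Paris is the capital of France.": "Paris est la capitale de la France."
}

# Toy English corpus (assumption for illustration only).
CORPUS: List[str] = [
    "Paris is the capital of France.",
    "Berlin is the capital of Germany.",
]


def translate(text: str, table: Dict[str, str]) -> str:
    """Return the table entry for `text`, or `text` unchanged if it is unknown."""
    return table.get(text, text)


def search(query: str, corpus: List[str]) -> List[str]:
    """Naive keyword search: keep documents that contain every query word."""
    words = query.lower().split()
    return [doc for doc in corpus if all(word in doc.lower() for word in words)]


def translated_search(query_fr: str) -> List[str]:
    query_en = translate(query_fr, FR_TO_EN)               # 1. French -> English
    hits_en = search(query_en, CORPUS)                     # 2. search the English corpus
    return [translate(hit, EN_TO_FR) for hit in hits_en]   # 3. English -> French


print(translated_search("capitale de la france"))
```

In the real component, steps 1 and 3 would be handled by two `TransformersTranslator` instances with opposite language directions, as in the wrapper example in the usage document.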
@ -1,3 +1,4 @@
|
||||
from abc import ABC
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
@ -13,9 +14,10 @@ from haystack.generator.base import BaseGenerator
|
||||
from haystack.reader.base import BaseReader
|
||||
from haystack.retriever.base import BaseRetriever
|
||||
from haystack.summarizer.base import BaseSummarizer
|
||||
from haystack.translator.base import BaseTranslator
|
||||
|
||||
|
||||
class Pipeline:
|
||||
class Pipeline(ABC):
|
||||
"""
|
||||
Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components.
|
||||
|
||||
@ -45,7 +47,7 @@ class Pipeline:
|
||||
In cases when the predecessor node has multiple outputs, e.g., a "QueryClassifier", the output
|
||||
must be specified explicitly as "QueryClassifier.output_2".
|
||||
"""
|
||||
self.graph.add_node(name, component=component)
|
||||
self.graph.add_node(name, component=component, inputs=inputs)
|
||||
|
||||
for i in inputs:
|
||||
if "." in i:
|
||||
@ -93,7 +95,7 @@ class Pipeline:
|
||||
while has_next_node:
|
||||
output_dict, stream_id = self.graph.nodes[current_node_id]["component"].run(**input_dict)
|
||||
input_dict = output_dict
|
||||
next_nodes = self._get_next_nodes(current_node_id, stream_id)
|
||||
next_nodes = self.get_next_nodes(current_node_id, stream_id)
|
||||
|
||||
if len(next_nodes) > 1:
|
||||
join_node_id = list(nx.neighbors(self.graph, next_nodes[0]))[0]
|
||||
@ -114,7 +116,7 @@ class Pipeline:
|
||||
|
||||
return output_dict
|
||||
|
||||
def _get_next_nodes(self, node_id: str, stream_id: str):
|
||||
def get_next_nodes(self, node_id: str, stream_id: str):
|
||||
current_node_edges = self.graph.edges(node_id, data=True)
|
||||
next_nodes = [
|
||||
next_node
|
||||
@ -259,7 +261,7 @@ class Pipeline:
|
||||
definition["params"][param_name] = value
|
||||
|
||||
|
||||
class BaseStandardPipeline:
|
||||
class BaseStandardPipeline(ABC):
|
||||
pipeline: Pipeline
|
||||
|
||||
def add_node(self, component, name: str, inputs: List[str]):
|
||||
@ -451,6 +453,52 @@ class FAQPipeline(BaseStandardPipeline):
|
||||
return results
|
||||
|
||||
|
||||
class TranslationWrapperPipeline(BaseStandardPipeline):
|
||||
|
||||
"""
|
||||
Takes an existing search pipeline and adds one "input translation node" after the Query and one
|
||||
"output translation" node just before returning the results
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_translator: BaseTranslator,
|
||||
output_translator: BaseTranslator,
|
||||
pipeline: BaseStandardPipeline
|
||||
):
|
||||
"""
|
||||
Wrap a given `pipeline` with the `input_translator` and `output_translator`.
|
||||
|
||||
:param input_translator: A Translator node that shall translate the input query from language A to B
|
||||
:param output_translator: A Translator node that shall translate the pipeline results from language B to A
|
||||
:param pipeline: The pipeline object (e.g. ExtractiveQAPipeline) you want to "wrap".
|
||||
Note that pipelines with split or merge nodes are currently not supported.
|
||||
"""
|
||||
|
||||
self.pipeline = Pipeline()
|
||||
self.pipeline.add_node(component=input_translator, name="InputTranslator", inputs=["Query"])
|
||||
|
||||
graph = pipeline.pipeline.graph
|
||||
previous_node_name = ["InputTranslator"]
|
||||
# Iterate over the nodes in insertion order (sufficient for linear pipelines, the only supported case)
|
||||
for node in graph.nodes:
|
||||
if node == "Query":
|
||||
continue
|
||||
|
||||
# TODO: Does not work properly for Join nodes and the Answer format
|
||||
if graph.nodes[node]["inputs"] and len(graph.nodes[node]["inputs"]) > 1:
|
||||
raise AttributeError("Split and merge nodes are not supported currently")
|
||||
|
||||
self.pipeline.add_node(name=node, component=graph.nodes[node]["component"], inputs=previous_node_name)
|
||||
previous_node_name = [node]
|
||||
|
||||
self.pipeline.add_node(component=output_translator, name="OutputTranslator", inputs=previous_node_name)
|
||||
|
||||
def run(self, **kwargs):
|
||||
output = self.pipeline.run(**kwargs)
|
||||
return output
|
||||
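The rewiring done in `TranslationWrapperPipeline.__init__` can be illustrated on plain data. This is a minimal sketch, assuming the wrapped pipeline is described by a dict that maps each node name to its list of inputs in insertion order; the function name `wrap_linear_pipeline` is hypothetical, and the real implementation works on the pipeline's `networkx` graph instead.

```python
from typing import Dict, List, Tuple


def wrap_linear_pipeline(nodes: Dict[str, List[str]]) -> List[Tuple[str, List[str]]]:
    """Return (node_name, inputs) pairs for the wrapped pipeline.

    `nodes` maps each node name to its input names, in the order the nodes were added.
    """
    # The input translator takes over the position right after "Query".
    wrapped: List[Tuple[str, List[str]]] = [("InputTranslator", ["Query"])]
    previous = ["InputTranslator"]

    for name, inputs in nodes.items():
        if name == "Query":
            continue
        # Nodes with more than one input (joins) cannot be chained linearly.
        if inputs and len(inputs) > 1:
            raise AttributeError("Split and merge nodes are not supported currently")
        # Re-attach the node to whatever came before it in the new chain.
        wrapped.append((name, previous))
        previous = [name]

    # The output translator is appended after the last original node.
    wrapped.append(("OutputTranslator", previous))
    return wrapped


print(wrap_linear_pipeline({"Query": [], "Retriever": ["Query"], "Reader": ["Retriever"]}))
```

Note that each node is re-attached to the previously visited node rather than to its original input, which is why only strictly linear pipelines survive the wrapping unchanged.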
|
||||
|
||||
class QueryNode:
|
||||
outgoing_edges = 1
|
||||
|
||||
|
||||
@ -6,8 +6,6 @@ from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.document_store.base import BaseDocumentStore
|
||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||
from haystack.document_store.memory import InMemoryDocumentStore
|
||||
from haystack import Document
|
||||
from haystack.retriever.base import BaseRetriever
|
||||
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import pipeline
|
||||
from transformers.models.auto.modeling_auto import AutoModelForSeq2SeqLM
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from haystack import Document
|
||||
from haystack.summarizer.base import BaseSummarizer
|
||||
|
||||
1
haystack/translator/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from haystack.translator.transformers import TransformersTranslator
|
||||
58
haystack/translator/base.py
Normal file
@ -0,0 +1,58 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
|
||||
from haystack import Document
|
||||
|
||||
|
||||
class BaseTranslator(ABC):
|
||||
"""
|
||||
Abstract class for a Translator component that translates either a query or a doc from language A to language B.
|
||||
"""
|
||||
|
||||
outgoing_edges = 1
|
||||
|
||||
@abstractmethod
|
||||
def translate(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
documents: Optional[Union[List[Document], List[str], List[Dict[str, Any]]]] = None,
|
||||
dict_key: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> Union[str, List[Document], List[str], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Translate the passed query or a list of documents from language A to B.
|
||||
"""
|
||||
pass
|
||||
|
||||
def run(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
documents: Optional[Union[List[Document], List[str], List[Dict[str, Any]]]] = None,
|
||||
answers: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
dict_key: Optional[str] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Method that gets executed when this class is used as a Node in a Haystack Pipeline"""
|
||||
|
||||
results: Dict = {
|
||||
**kwargs
|
||||
}
|
||||
|
||||
# This will cover input query stage
|
||||
if query:
|
||||
results["query"] = self.translate(query=query)
|
||||
# This will cover retriever and summarizer
|
||||
if documents:
|
||||
dict_key = dict_key or "text"
|
||||
results["documents"] = self.translate(documents=documents, dict_key=dict_key)
|
||||
|
||||
if answers:
|
||||
dict_key = dict_key or "answer"
|
||||
if isinstance(answers, Mapping):
|
||||
# This will cover reader
|
||||
results["answers"] = self.translate(documents=answers["answers"], dict_key=dict_key)
|
||||
else:
|
||||
# This will cover generator
|
||||
results["answers"] = self.translate(documents=answers, dict_key=dict_key)
|
||||
|
||||
return results, "output_1"
|
||||
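The dispatch in `BaseTranslator.run` (query, documents, or answers, with a different default `dict_key` for answers) can be tried out with a stub. The sketch below is an assumption-laden illustration: `UppercaseTranslator` is a hypothetical class that "translates" by upper-casing and only handles lists of dicts, and it does not inherit from the real `BaseTranslator`. Unlike the source, which reuses a single `dict_key` variable across the documents and answers branches, the sketch resolves the key separately in each branch.

```python
from collections.abc import Mapping
from typing import Any, Dict, List, Optional, Tuple


class UppercaseTranslator:
    """Hypothetical stand-in: 'translates' by upper-casing text."""

    def translate(
        self,
        query: Optional[str] = None,
        documents: Optional[List[Dict[str, Any]]] = None,
        dict_key: Optional[str] = None,
    ):
        if query:
            return query.upper()
        key = dict_key or "text"
        return [{**doc, key: doc[key].upper()} for doc in documents or []]

    def run(
        self,
        query: Optional[str] = None,
        documents: Optional[List[Dict[str, Any]]] = None,
        answers: Any = None,
        dict_key: Optional[str] = None,
        **kwargs,
    ) -> Tuple[Dict[str, Any], str]:
        # Unknown keyword arguments are passed through to the next node untouched.
        results: Dict[str, Any] = {**kwargs}
        if query:
            results["query"] = self.translate(query=query)
        if documents:
            results["documents"] = self.translate(documents=documents, dict_key=dict_key or "text")
        if answers:
            # Reader-style output is a dict with an "answers" list; generator-style is a list.
            items = answers["answers"] if isinstance(answers, Mapping) else answers
            results["answers"] = self.translate(documents=items, dict_key=dict_key or "answer")
        return results, "output_1"


node = UppercaseTranslator()
print(node.run(answers={"answers": [{"answer": "berlin"}]}, top_k_reader=1))
```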
127
haystack/translator/transformers.py
Normal file
@ -0,0 +1,127 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
from haystack import Document
|
||||
from haystack.translator.base import BaseTranslator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransformersTranslator(BaseTranslator):
|
||||
"""
|
||||
Translator component based on Seq2Seq models from Huggingface's transformers library.
|
||||
Exemplary use cases:
|
||||
- Translate a query from Language A to B (e.g. if you only have good models + documents in language B)
|
||||
- Translate a document from Language A to B (e.g. if you want to return results in the native language of the user)
|
||||
|
||||
We currently recommend using OPUS models (see __init__() for details)
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
| DOCS = [
|
||||
| Document(text="Heinz von Foerster was an Austrian American scientist combining physics and philosophy, "
| "and widely attributed as the originator of Second-order cybernetics.")
|
||||
| ]
|
||||
| translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-en-de")
|
||||
| res = translator.translate(documents=DOCS, query=None)
|
||||
```
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
tokenizer_name: Optional[str] = None,
|
||||
max_seq_len: Optional[int] = None,
|
||||
clean_up_tokenization_spaces: Optional[bool] = True
|
||||
):
|
||||
""" Initialize the translator with a model that fits your targeted languages. While we support all seq2seq
|
||||
models from Hugging Face's model hub, we recommend using the OPUS models from Helsinki-NLP. They provide plenty
|
||||
of different models, usually one model per language pair and translation direction.
|
||||
They have a pretty standardized naming that should help you find the right model:
|
||||
- "Helsinki-NLP/opus-mt-en-de" => translating from English to German
|
||||
- "Helsinki-NLP/opus-mt-de-en" => translating from German to English
|
||||
- "Helsinki-NLP/opus-mt-fr-en" => translating from French to English
|
||||
- "Helsinki-NLP/opus-mt-hi-en" => translating from Hindi to English
|
||||
...
|
||||
|
||||
They also have a few multilingual models that support multiple languages at once.
|
||||
|
||||
:param model_name_or_path: Name of the seq2seq model that shall be used for translation.
|
||||
Can be a remote name from Huggingface's modelhub or a local path.
|
||||
:param tokenizer_name: Optional tokenizer name. If not supplied, `model_name_or_path` will also be used for the
|
||||
tokenizer.
|
||||
:param max_seq_len: The maximum sentence length the model accepts. (Optional)
|
||||
:param clean_up_tokenization_spaces: Whether or not to clean up the tokenization spaces. (default True)
|
||||
"""
|
||||
|
||||
self.max_seq_len = max_seq_len
|
||||
self.clean_up_tokenization_spaces = clean_up_tokenization_spaces
|
||||
tokenizer_name = tokenizer_name or model_name_or_path
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_name
|
||||
)
|
||||
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
|
||||
|
||||
def translate(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
documents: Optional[Union[List[Document], List[str], List[Dict[str, Any]]]] = None,
|
||||
dict_key: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> Union[str, List[Document], List[str], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Run the actual translation. You can supply a query or a list of documents. Whatever is supplied will be translated.
|
||||
"""
|
||||
if not query and not documents:
|
||||
raise AttributeError("Translator needs a query or documents to perform translation")
|
||||
|
||||
if query and documents:
|
||||
raise AttributeError("Translator needs either a query or documents, but not both")
|
||||
|
||||
if documents and len(documents) == 0:
|
||||
logger.warning("Empty documents list is passed")
|
||||
return documents
|
||||
|
||||
dict_key = dict_key or "text"
|
||||
|
||||
if isinstance(documents, list):
|
||||
if isinstance(documents[0], Document):
|
||||
text_for_translator = [doc.text for doc in documents] # type: ignore
|
||||
elif isinstance(documents[0], str):
|
||||
text_for_translator = documents # type: ignore
|
||||
else:
|
||||
if not isinstance(documents[0].get(dict_key, None), str): # type: ignore
|
||||
raise AttributeError(f"Dictionary should have {dict_key} key and its value should be `str` type")
|
||||
text_for_translator = [doc[dict_key] for doc in documents] # type: ignore
|
||||
else:
|
||||
text_for_translator: List[str] = [query] # type: ignore
|
||||
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=text_for_translator,
|
||||
return_tensors="pt",
|
||||
max_length=self.max_seq_len
|
||||
)
|
||||
generated_output = self.model.generate(**batch)
|
||||
translated_texts = self.tokenizer.batch_decode(
|
||||
generated_output,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=self.clean_up_tokenization_spaces
|
||||
)
|
||||
|
||||
if query:
|
||||
return translated_texts[0]
|
||||
elif documents:
|
||||
if isinstance(documents, list) and isinstance(documents[0], str):
|
||||
return [translated_text for translated_text in translated_texts]
|
||||
|
||||
for translated_text, doc in zip(translated_texts, documents):
|
||||
if isinstance(doc, Document):
|
||||
doc.text = translated_text
|
||||
else:
|
||||
doc[dict_key] = translated_text # type: ignore
|
||||
|
||||
return documents
|
||||
|
||||
raise AttributeError("Translator needs a query or documents to perform translation")
|
||||
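The input handling in `translate` (extract the text from `Document`, `str`, or `dict` items, translate, then write the result back into the same container type) can be sketched independently of any model. Everything below is a hedged illustration: `Doc` is a hypothetical stand-in for Haystack's `Document`, `translate_items` is not a Haystack function, and the translation itself is replaced by an arbitrary callable. One deliberate difference from the source: an empty list is returned as-is here, whereas in the source an empty list is falsy and therefore appears to hit the first "query or documents" check before the empty-list warning can be reached.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Union


@dataclass
class Doc:
    """Hypothetical stand-in for haystack's Document (only the `text` field)."""
    text: str


Item = Union[Doc, str, Dict[str, Any]]


def translate_items(items: List[Item], fn: Callable[[str], str], dict_key: str = "text") -> List[Item]:
    """Apply `fn` to the text of each item, preserving the container type."""
    if not items:
        return items

    # As in the source, the type of the first item decides how texts are extracted.
    first = items[0]
    if isinstance(first, Doc):
        texts = [item.text for item in items]
    elif isinstance(first, str):
        texts = list(items)
    else:
        if not isinstance(first.get(dict_key), str):
            raise AttributeError(f"Dictionary should have {dict_key} key with a `str` value")
        texts = [item[dict_key] for item in items]

    translated = [fn(text) for text in texts]

    # Plain strings are returned as a new list; Doc and dict items are updated in place.
    if isinstance(first, str):
        return translated
    for new_text, item in zip(translated, items):
        if isinstance(item, Doc):
            item.text = new_text
        else:
            item[dict_key] = new_text
    return items


print(translate_items([{"key": "I live in Berlin"}], str.upper, dict_key="key"))
```

Because `Doc` and dict inputs are modified in place, callers who need the untranslated originals would have to copy them before calling; this mirrors the write-back loop at the end of `translate` in the diff.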
@ -23,6 +23,7 @@ from haystack.document_store.sql import SQLDocumentStore
|
||||
from haystack.reader.farm import FARMReader
|
||||
from haystack.reader.transformers import TransformersReader
|
||||
from haystack.summarizer.transformers import TransformersSummarizer
|
||||
from haystack.translator import TransformersTranslator
|
||||
|
||||
|
||||
def _sql_session_rollback(self, attr):
|
||||
@ -161,6 +162,20 @@ def summarizer():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def en_to_de_translator():
|
||||
return TransformersTranslator(
|
||||
model_name_or_path="Helsinki-NLP/opus-mt-en-de",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def de_to_en_translator():
|
||||
return TransformersTranslator(
|
||||
model_name_or_path="Helsinki-NLP/opus-mt-de-en",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def test_docs_xs():
|
||||
return [
|
||||
|
||||
@ -2,7 +2,7 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
from haystack import Document
|
||||
from haystack.pipeline import GenerativeQAPipeline
|
||||
from haystack.pipeline import TranslationWrapperPipeline, GenerativeQAPipeline
|
||||
|
||||
DOCS_WITH_EMBEDDINGS = [
|
||||
Document(
|
||||
@ -426,3 +426,33 @@ def test_generator_pipeline(document_store, retriever, rag_generator):
|
||||
answers = output["answers"]
|
||||
assert len(answers) == 2
|
||||
assert "berlin" in answers[0]["answer"]
|
||||
|
||||
|
||||
# Keep only a few (retriever, document_store) combinations to reduce test time
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.generator
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize(
|
||||
"retriever,document_store",
|
||||
[("embedding", "memory"), ("elasticsearch", "elasticsearch")],
|
||||
indirect=True,
|
||||
)
|
||||
def test_generator_pipeline_with_translator(
|
||||
document_store,
|
||||
retriever,
|
||||
rag_generator,
|
||||
en_to_de_translator,
|
||||
de_to_en_translator
|
||||
):
|
||||
document_store.write_documents(DOCS_WITH_EMBEDDINGS)
|
||||
query = "Was ist die Hauptstadt der Bundesrepublik Deutschland?"
|
||||
base_pipeline = GenerativeQAPipeline(retriever=retriever, generator=rag_generator)
|
||||
pipeline = TranslationWrapperPipeline(
|
||||
input_translator=de_to_en_translator,
|
||||
output_translator=en_to_de_translator,
|
||||
pipeline=base_pipeline
|
||||
)
|
||||
output = pipeline.run(query=query, top_k_generator=2, top_k_retriever=1)
|
||||
answers = output["answers"]
|
||||
assert len(answers) == 2
|
||||
assert "berlin" in answers[0]["answer"]
|
||||
|
||||
@ -3,7 +3,8 @@ from pathlib import Path
|
||||
import pytest
|
||||
|
||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||
from haystack.pipeline import JoinDocuments, ExtractiveQAPipeline, Pipeline, FAQPipeline, DocumentSearchPipeline
|
||||
from haystack.pipeline import TranslationWrapperPipeline, JoinDocuments, ExtractiveQAPipeline, Pipeline, FAQPipeline, \
|
||||
DocumentSearchPipeline
|
||||
from haystack.retriever.dense import DensePassageRetriever
|
||||
from haystack.retriever.sparse import ElasticsearchRetriever
|
||||
|
||||
@ -137,6 +138,27 @@ def test_document_search_pipeline(retriever, document_store):
|
||||
assert len(output["documents"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
|
||||
def test_extractive_qa_answers_with_translator(reader, retriever_with_docs, en_to_de_translator, de_to_en_translator):
|
||||
base_pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
|
||||
pipeline = TranslationWrapperPipeline(
|
||||
input_translator=de_to_en_translator,
|
||||
output_translator=en_to_de_translator,
|
||||
pipeline=base_pipeline
|
||||
)
|
||||
|
||||
prediction = pipeline.run(query="Wer lebt in Berlin?", top_k_retriever=10, top_k_reader=3)
|
||||
assert prediction is not None
|
||||
assert prediction["query"] == "Wer lebt in Berlin?"
|
||||
assert "Carla" in prediction["answers"][0]["answer"]
|
||||
assert prediction["answers"][0]["probability"] <= 1
|
||||
assert prediction["answers"][0]["probability"] >= 0
|
||||
assert prediction["answers"][0]["meta"]["meta_field"] == "test1"
|
||||
assert prediction["answers"][0]["context"] == "My name is Carla and I live in Berlin"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||
def test_join_document_pipeline(document_store_with_docs, reader):
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from haystack import Document
|
||||
from haystack.pipeline import SearchSummarizationPipeline
|
||||
from haystack.pipeline import TranslationWrapperPipeline, SearchSummarizationPipeline
|
||||
from haystack.retriever.dense import DensePassageRetriever, EmbeddingRetriever
|
||||
|
||||
DOCS = [
|
||||
@ -94,3 +94,41 @@ def test_summarization_pipeline_one_summary(document_store, retriever, summarize
|
||||
answers = output["answers"]
|
||||
assert len(answers) == 1
|
||||
assert answers[0]["answer"] in EXPECTED_ONE_SUMMARIES
|
||||
|
||||
|
||||
# Keep only a few (retriever, document_store) combinations to reduce test time
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.summarizer
|
||||
@pytest.mark.parametrize(
|
||||
"retriever,document_store",
|
||||
[("embedding", "memory"), ("elasticsearch", "elasticsearch")],
|
||||
indirect=True,
|
||||
)
|
||||
def test_summarization_pipeline_with_translator(
|
||||
document_store,
|
||||
retriever,
|
||||
summarizer,
|
||||
en_to_de_translator,
|
||||
de_to_en_translator
|
||||
):
|
||||
document_store.write_documents(SPLIT_DOCS)
|
||||
|
||||
if isinstance(retriever, EmbeddingRetriever) or isinstance(retriever, DensePassageRetriever):
|
||||
document_store.update_embeddings(retriever=retriever)
|
||||
|
||||
query = "Wo steht der Eiffelturm?"
|
||||
base_pipeline = SearchSummarizationPipeline(retriever=retriever, summarizer=summarizer)
|
||||
pipeline = TranslationWrapperPipeline(
|
||||
input_translator=de_to_en_translator,
|
||||
output_translator=en_to_de_translator,
|
||||
pipeline=base_pipeline
|
||||
)
|
||||
output = pipeline.run(query=query, top_k_retriever=2, generate_single_summary=True)
|
||||
# SearchSummarizationPipeline returns answers, but the Summarizer node returns documents
|
||||
documents = output["documents"]
|
||||
assert len(documents) == 1
|
||||
assert documents[0].text in [
|
||||
"Der Eiffelturm ist ein Wahrzeichen in Paris, Frankreich.",
|
||||
"Der Eiffelturm, der 1889 in Paris, Frankreich, erbaut wurde, ist das höchste freistehende Bauwerk der Welt."
|
||||
]
|
||||
|
||||
46
test/test_translator.py
Normal file
@ -0,0 +1,46 @@
|
||||
from haystack import Document
|
||||
|
||||
import pytest
|
||||
|
||||
EXPECTED_OUTPUT = "Ich lebe in Berlin"
|
||||
INPUT = "I live in Berlin"
|
||||
|
||||
|
||||
def test_translator_with_query(en_to_de_translator):
|
||||
assert en_to_de_translator.translate(query=INPUT) == EXPECTED_OUTPUT
|
||||
|
||||
|
||||
def test_translator_with_list(en_to_de_translator):
|
||||
assert en_to_de_translator.translate(documents=[INPUT])[0] == EXPECTED_OUTPUT
|
||||
|
||||
|
||||
def test_translator_with_document(en_to_de_translator):
|
||||
assert en_to_de_translator.translate(documents=[Document(text=INPUT)])[0].text == EXPECTED_OUTPUT
|
||||
|
||||
|
||||
def test_translator_with_dictionary(en_to_de_translator):
|
||||
assert en_to_de_translator.translate(documents=[{"text": INPUT}])[0]["text"] == EXPECTED_OUTPUT
|
||||
|
||||
|
||||
def test_translator_with_dictionary_with_dict_key(en_to_de_translator):
|
||||
assert en_to_de_translator.translate(documents=[{"key": INPUT}], dict_key="key")[0]["key"] == EXPECTED_OUTPUT
|
||||
|
||||
|
||||
def test_translator_with_empty_input(en_to_de_translator):
|
||||
with pytest.raises(AttributeError):
|
||||
en_to_de_translator.translate()
|
||||
|
||||
|
||||
def test_translator_with_query_and_documents(en_to_de_translator):
|
||||
with pytest.raises(AttributeError):
|
||||
en_to_de_translator.translate(query=INPUT, documents=[INPUT])
|
||||
|
||||
|
||||
def test_translator_with_dict_without_text_key(en_to_de_translator):
|
||||
with pytest.raises(AttributeError):
|
||||
en_to_de_translator.translate(documents=[{"text1": INPUT}])
|
||||
|
||||
|
||||
def test_translator_with_dict_with_non_string_value(en_to_de_translator):
|
||||
with pytest.raises(AttributeError):
|
||||
en_to_de_translator.translate(documents=[{"text": 123}])
|
||||