diff --git a/docs/_src/api/api/pipelines.md b/docs/_src/api/api/pipelines.md
index de88f5a1c..1d0f963da 100644
--- a/docs/_src/api/api/pipelines.md
+++ b/docs/_src/api/api/pipelines.md
@@ -5,7 +5,7 @@
## Pipeline Objects
```python
-class Pipeline()
+class Pipeline(ABC)
```
Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components.
@@ -131,7 +131,7 @@ variable 'MYDOCSTORE_PARAMS_INDEX=documents-2021' can be set. Note that an
## BaseStandardPipeline Objects
```python
-class BaseStandardPipeline()
+class BaseStandardPipeline(ABC)
```
@@ -316,6 +316,32 @@ Initialize a Pipeline for finding similar FAQs using semantic document search.
- `retriever`: Retriever instance
+
+## TranslationWrapperPipeline Objects
+
+```python
+class TranslationWrapperPipeline(BaseStandardPipeline)
+```
+
+Takes an existing search pipeline and adds one "input translation node" after the Query and one
+"output translation" node just before returning the results
+
+
+#### \_\_init\_\_
+
+```python
+ | __init__(input_translator: BaseTranslator, output_translator: BaseTranslator, pipeline: BaseStandardPipeline)
+```
+
+Wrap a given `pipeline` with the `input_translator` and `output_translator`.
+
+**Arguments**:
+
+- `input_translator`: A Translator node that shall translate the input query from language A to B
+- `output_translator`: A Translator node that shall translate the pipeline results from language B to A
+- `pipeline`: The pipeline object (e.g. ExtractiveQAPipeline) you want to "wrap".
+Note that pipelines with split or merge nodes are currently not supported.
+
## JoinDocuments Objects
diff --git a/docs/_src/usage/usage/translator.md b/docs/_src/usage/usage/translator.md
new file mode 100644
index 000000000..d7b71dadf
--- /dev/null
+++ b/docs/_src/usage/usage/translator.md
@@ -0,0 +1,58 @@
+
+
+# Translator
+
+Texts come in different languages. Search is no exception, and there are plenty of options to deal with it.
+One of them is to translate the incoming query, the documents or the search results.
+
+Let's imagine you have an English corpus of technical docs, but the mother tongue of many of your users is French.
+You can use a Translator node in your pipeline to
+1. Translate the incoming query from French to English
+2. Search in your English corpus for the right document / answer
+3. Translate the results back from English to French
+
+
+
+**Example (Stand-alone Translator)**
+
+You can use the Translator component directly to translate your query or document(s):
+```python
+DOCS = [
+ Document(
+ text="""Heinz von Foerster was an Austrian American scientist
+ combining physics and philosophy, and widely attributed
+ as the originator of Second-order cybernetics."""
+ )
+ ]
+translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-en-fr")
+res = translator.translate(documents=DOCS, query=None)
+```
+
+**Example (Wrapping another Pipeline)**
+
+You can also wrap one of your existing pipelines and "add" the translation nodes at the beginning and at the end of your pipeline.
+For example, let's translate the incoming query from French to English, then do our document retrieval, and then translate the results back from English to French:
+
+```python
+from haystack.pipeline import TranslationWrapperPipeline, DocumentSearchPipeline
+from haystack.translator import TransformersTranslator
+
+pipeline = DocumentSearchPipeline(retriever=my_dpr_retriever)
+
+in_translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-fr-en")
+out_translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-en-fr")
+
+pipeline_with_translation = TranslationWrapperPipeline(input_translator=in_translator,
+ output_translator=out_translator,
+ pipeline=pipeline)
+```
+
+
+
diff --git a/haystack/pipeline.py b/haystack/pipeline.py
index f05eb7a6c..a45394669 100644
--- a/haystack/pipeline.py
+++ b/haystack/pipeline.py
@@ -1,3 +1,4 @@
+from abc import ABC
import os
from copy import deepcopy
from pathlib import Path
@@ -13,9 +14,10 @@ from haystack.generator.base import BaseGenerator
from haystack.reader.base import BaseReader
from haystack.retriever.base import BaseRetriever
from haystack.summarizer.base import BaseSummarizer
+from haystack.translator.base import BaseTranslator
-class Pipeline:
+class Pipeline(ABC):
"""
Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components.
@@ -45,7 +47,7 @@ class Pipeline:
In cases when the predecessor node has multiple outputs, e.g., a "QueryClassifier", the output
must be specified explicitly as "QueryClassifier.output_2".
"""
- self.graph.add_node(name, component=component)
+ self.graph.add_node(name, component=component, inputs=inputs)
for i in inputs:
if "." in i:
@@ -93,7 +95,7 @@ class Pipeline:
while has_next_node:
output_dict, stream_id = self.graph.nodes[current_node_id]["component"].run(**input_dict)
input_dict = output_dict
- next_nodes = self._get_next_nodes(current_node_id, stream_id)
+ next_nodes = self.get_next_nodes(current_node_id, stream_id)
if len(next_nodes) > 1:
join_node_id = list(nx.neighbors(self.graph, next_nodes[0]))[0]
@@ -114,7 +116,7 @@ class Pipeline:
return output_dict
- def _get_next_nodes(self, node_id: str, stream_id: str):
+ def get_next_nodes(self, node_id: str, stream_id: str):
current_node_edges = self.graph.edges(node_id, data=True)
next_nodes = [
next_node
@@ -259,7 +261,7 @@ class Pipeline:
definition["params"][param_name] = value
-class BaseStandardPipeline:
+class BaseStandardPipeline(ABC):
pipeline: Pipeline
def add_node(self, component, name: str, inputs: List[str]):
@@ -451,6 +453,52 @@ class FAQPipeline(BaseStandardPipeline):
return results
+class TranslationWrapperPipeline(BaseStandardPipeline):
+    """
+    Takes an existing search pipeline and adds one "input translation node" after the Query and one
+    "output translation" node just before returning the results.
+    """
+
+    def __init__(
+        self,
+        input_translator: BaseTranslator,
+        output_translator: BaseTranslator,
+        pipeline: BaseStandardPipeline
+    ):
+        """
+        Wrap a given `pipeline` with the `input_translator` and `output_translator`.
+
+        :param input_translator: A Translator node that shall translate the input query from language A to B
+        :param output_translator: A Translator node that shall translate the pipeline results from language B to A
+        :param pipeline: The pipeline object (e.g. ExtractiveQAPipeline) you want to "wrap".
+                         Note that pipelines with split or merge nodes are currently not supported.
+        :raises AttributeError: If the wrapped pipeline is not a plain linear chain of nodes.
+        """
+        self.pipeline = Pipeline()
+        self.pipeline.add_node(component=input_translator, name="InputTranslator", inputs=["Query"])
+
+        graph = pipeline.pipeline.graph
+        expected_input = "Query"  # name the next wrapped node has to list as its only input
+        previous_node_name = ["InputTranslator"]
+        # Nodes are visited in the order the graph yields them (not a real BFS), so only linear chains are safe.
+        for node in graph.nodes:
+            if node == "Query":
+                continue
+            # TODO: Does not work properly yet for Join nodes and the Answer format
+            inputs = graph.nodes[node].get("inputs") or []
+            # Merge node: more than one input. Split node: the input is not the node visited just before.
+            if len(inputs) > 1 or any(i.split(".")[0] != expected_input for i in inputs):
+                raise AttributeError("Split and merge nodes are not supported currently")
+            self.pipeline.add_node(name=node, component=graph.nodes[node]["component"], inputs=previous_node_name)
+            expected_input, previous_node_name = node, [node]
+
+        self.pipeline.add_node(component=output_translator, name="OutputTranslator", inputs=previous_node_name)
+
+    def run(self, **kwargs):
+        """Run the re-wired pipeline, forwarding all keyword arguments, and return its output."""
+        return self.pipeline.run(**kwargs)
+
+
class QueryNode:
outgoing_edges = 1
diff --git a/haystack/retriever/dense.py b/haystack/retriever/dense.py
index 90a619481..71674bd71 100644
--- a/haystack/retriever/dense.py
+++ b/haystack/retriever/dense.py
@@ -6,8 +6,6 @@ from pathlib import Path
from tqdm import tqdm
from haystack.document_store.base import BaseDocumentStore
-from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
-from haystack.document_store.memory import InMemoryDocumentStore
from haystack import Document
from haystack.retriever.base import BaseRetriever
diff --git a/haystack/summarizer/transformers.py b/haystack/summarizer/transformers.py
index 8f6f9d986..6a548c24d 100644
--- a/haystack/summarizer/transformers.py
+++ b/haystack/summarizer/transformers.py
@@ -1,9 +1,8 @@
import logging
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
from transformers import pipeline
from transformers.models.auto.modeling_auto import AutoModelForSeq2SeqLM
-from transformers import AutoTokenizer
from haystack import Document
from haystack.summarizer.base import BaseSummarizer
diff --git a/haystack/translator/__init__.py b/haystack/translator/__init__.py
new file mode 100644
index 000000000..c4c6ee277
--- /dev/null
+++ b/haystack/translator/__init__.py
@@ -0,0 +1 @@
+from haystack.translator.transformers import TransformersTranslator
diff --git a/haystack/translator/base.py b/haystack/translator/base.py
new file mode 100644
index 000000000..d6c75efce
--- /dev/null
+++ b/haystack/translator/base.py
@@ -0,0 +1,58 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Mapping, Optional, Union
+
+from haystack import Document
+
+
+class BaseTranslator(ABC):
+    """
+    Abstract class for a Translator component that translates either a query or a doc from language A to language B.
+    """
+
+    outgoing_edges = 1
+
+    @abstractmethod
+    def translate(
+        self,
+        query: Optional[str] = None,
+        documents: Optional[Union[List[Document], List[str], List[Dict[str, Any]]]] = None,
+        dict_key: Optional[str] = None,
+        **kwargs
+    ) -> Union[str, List[Document], List[str], List[Dict[str, Any]]]:
+        """
+        Translate the passed query or a list of documents from language A to B.
+        """
+        pass
+
+    def run(
+        self,
+        query: Optional[str] = None,
+        documents: Optional[Union[List[Document], List[str], List[Dict[str, Any]]]] = None,
+        answers: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+        dict_key: Optional[str] = None,
+        **kwargs
+    ):
+        """
+        Method that gets executed when this class is used as a Node in a Haystack Pipeline.
+        Translates every non-empty input among `query`, `documents` and `answers`; `kwargs` pass through untouched.
+
+        :param dict_key: Key that holds the text of dict items. Defaults to "text" for documents, "answer" for answers.
+        :return: Tuple of the results dict and the name of the outgoing edge ("output_1").
+        """
+        results: Dict = dict(kwargs)
+
+        # This will cover input query stage
+        if query:
+            results["query"] = self.translate(query=query)
+        # This will cover retriever and summarizer. The default key is resolved per input so that the
+        # "text" fallback cannot leak into the answers branch when both inputs are present.
+        if documents:
+            results["documents"] = self.translate(documents=documents, dict_key=dict_key or "text")
+
+        if answers:
+            answer_key = dict_key or "answer"
+            # A mapping is the reader format (list under "answers"); a plain list is the generator format.
+            answer_items = answers["answers"] if isinstance(answers, Mapping) else answers
+            results["answers"] = self.translate(documents=answer_items, dict_key=answer_key)
+
+        return results, "output_1"
\ No newline at end of file
diff --git a/haystack/translator/transformers.py b/haystack/translator/transformers.py
new file mode 100644
index 000000000..de90a760c
--- /dev/null
+++ b/haystack/translator/transformers.py
@@ -0,0 +1,127 @@
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+from haystack import Document
+from haystack.translator.base import BaseTranslator
+
+logger = logging.getLogger(__name__)
+
+
+class TransformersTranslator(BaseTranslator):
+    """
+    Translator component based on Seq2Seq models from Huggingface's transformers library.
+    Exemplary use cases:
+    - Translate a query from Language A to B (e.g. if you only have good models + documents in language B)
+    - Translate a document from Language A to B (e.g. if you want to return results in the native language of the user)
+
+    We currently recommend using OPUS models (see __init__() for details)
+
+    **Example:**
+
+    ```python
+    |    DOCS = [
+    |        Document(text="Heinz von Foerster was an Austrian American scientist combining physics and philosophy,
+    |                       and widely attributed as the originator of Second-order cybernetics.")
+    |    ]
+    |    translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-en-de")
+    |    res = translator.translate(documents=DOCS, query=None)
+    ```
+    """
+    def __init__(
+        self,
+        model_name_or_path: str,
+        tokenizer_name: Optional[str] = None,
+        max_seq_len: Optional[int] = None,
+        clean_up_tokenization_spaces: Optional[bool] = True
+    ):
+        """ Initialize the translator with a model that fits your targeted languages. While we support all seq2seq
+        models from Hugging Face's model hub, we recommend using the OPUS models from Helsinki-NLP. They provide plenty
+        of different models, usually one model per language pair and translation direction.
+        They have a pretty standardized naming that should help you find the right model:
+        - "Helsinki-NLP/opus-mt-en-de" => translating from English to German
+        - "Helsinki-NLP/opus-mt-de-en" => translating from German to English
+        - "Helsinki-NLP/opus-mt-fr-en" => translating from French to English
+        - "Helsinki-NLP/opus-mt-hi-en" => translating from Hindi to English
+        ...
+
+        They also have a few multilingual models that support multiple languages at once.
+
+        :param model_name_or_path: Name of the seq2seq model that shall be used for translation.
+                                   Can be a remote name from Huggingface's modelhub or a local path.
+        :param tokenizer_name: Optional tokenizer name. If not supplied, `model_name_or_path` will also be used for the
+                               tokenizer.
+        :param max_seq_len: The maximum sentence length the model accepts. (Optional)
+        :param clean_up_tokenization_spaces: Whether or not to clean up the tokenization spaces. (default True)
+        """
+        self.max_seq_len = max_seq_len
+        self.clean_up_tokenization_spaces = clean_up_tokenization_spaces
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
+
+    def translate(
+        self,
+        query: Optional[str] = None,
+        documents: Optional[Union[List[Document], List[str], List[Dict[str, Any]]]] = None,
+        dict_key: Optional[str] = None,
+        **kwargs
+    ) -> Union[str, List[Document], List[str], List[Dict[str, Any]]]:
+        """
+        Run the actual translation. You can supply a query or a list of documents. Whatever is supplied will be translated.
+
+        :param query: Single text to translate. Cannot be combined with `documents`.
+        :param documents: Documents, strings or dicts to translate. Documents and dicts are updated in place.
+        :param dict_key: Key that holds the text if `documents` contains dicts. Defaults to "text".
+        :return: The translated string for a query, otherwise the list of translated documents.
+        :raises AttributeError: If neither or both of `query` and `documents` are given, or a dict has no `str` under `dict_key`.
+        """
+        if not query and documents is None:
+            raise AttributeError("Translator need query or documents to perform translation")
+
+        if query and documents:
+            raise AttributeError("Translator need either query or documents but not both")
+
+        # An empty list is falsy, so it needs an explicit length check to get the warning instead of an error.
+        if not query and len(documents) == 0:  # type: ignore
+            logger.warning("Empty documents list is passed")
+            return documents  # type: ignore
+
+        dict_key = dict_key or "text"
+
+        if query:
+            text_for_translator: List[str] = [query]
+        elif isinstance(documents[0], Document):  # type: ignore
+            text_for_translator = [doc.text for doc in documents]  # type: ignore
+        elif isinstance(documents[0], str):  # type: ignore
+            text_for_translator = documents  # type: ignore
+        else:
+            # Check every dict up front so that a bad item fails with a clear message before the model runs.
+            if not all(isinstance(doc.get(dict_key, None), str) for doc in documents):  # type: ignore
+                raise AttributeError(f"Dictionary should have {dict_key} key and it's value should be `str` type")
+            text_for_translator = [doc[dict_key] for doc in documents]  # type: ignore
+
+        batch = self.tokenizer.prepare_seq2seq_batch(
+            src_texts=text_for_translator,
+            return_tensors="pt",
+            max_length=self.max_seq_len
+        )
+        generated_output = self.model.generate(**batch)
+        translated_texts = self.tokenizer.batch_decode(
+            generated_output,
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=self.clean_up_tokenization_spaces
+        )
+
+        if query:
+            return translated_texts[0]
+        if isinstance(documents[0], str):  # type: ignore
+            return list(translated_texts)
+
+        # Documents and dicts are updated in place and the same list object is handed back.
+        for translated_text, doc in zip(translated_texts, documents):  # type: ignore
+            if isinstance(doc, Document):
+                doc.text = translated_text
+            else:
+                doc[dict_key] = translated_text  # type: ignore
+
+        return documents  # type: ignore
diff --git a/test/conftest.py b/test/conftest.py
index 8aefa3945..89fb0d0a8 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -23,6 +23,7 @@ from haystack.document_store.sql import SQLDocumentStore
from haystack.reader.farm import FARMReader
from haystack.reader.transformers import TransformersReader
from haystack.summarizer.transformers import TransformersSummarizer
+from haystack.translator import TransformersTranslator
def _sql_session_rollback(self, attr):
@@ -161,6 +162,20 @@ def summarizer():
)
+@pytest.fixture(scope="module")
+def en_to_de_translator():
+    """Provide one English-to-German OPUS translator per test module."""
+    model_name = "Helsinki-NLP/opus-mt-en-de"
+    return TransformersTranslator(model_name_or_path=model_name)
+
+
+@pytest.fixture(scope="module")
+def de_to_en_translator():
+    """Provide one German-to-English OPUS translator per test module."""
+    model_name = "Helsinki-NLP/opus-mt-de-en"
+    return TransformersTranslator(model_name_or_path=model_name)
+
+
@pytest.fixture(scope="module")
def test_docs_xs():
return [
diff --git a/test/test_generator.py b/test/test_generator.py
index bf8a6083a..e4715ead6 100644
--- a/test/test_generator.py
+++ b/test/test_generator.py
@@ -2,7 +2,7 @@ import numpy as np
import pytest
from haystack import Document
-from haystack.pipeline import GenerativeQAPipeline
+from haystack.pipeline import TranslationWrapperPipeline, GenerativeQAPipeline
DOCS_WITH_EMBEDDINGS = [
Document(
@@ -426,3 +426,33 @@ def test_generator_pipeline(document_store, retriever, rag_generator):
answers = output["answers"]
assert len(answers) == 2
assert "berlin" in answers[0]["answer"]
+
+
+# Only two (retriever, document_store) combinations are exercised to keep the test time down
+@pytest.mark.slow
+@pytest.mark.generator
+@pytest.mark.elasticsearch
+@pytest.mark.parametrize(
+    "retriever,document_store",
+    [("embedding", "memory"), ("elasticsearch", "elasticsearch")],
+    indirect=True,
+)
+def test_generator_pipeline_with_translator(
+    document_store,
+    retriever,
+    rag_generator,
+    en_to_de_translator,
+    de_to_en_translator
+):
+    """Wrap a GenerativeQAPipeline with de->en / en->de translators and check the two generated answers."""
+    document_store.write_documents(DOCS_WITH_EMBEDDINGS)
+    wrapped = TranslationWrapperPipeline(
+        input_translator=de_to_en_translator,
+        output_translator=en_to_de_translator,
+        pipeline=GenerativeQAPipeline(retriever=retriever, generator=rag_generator),
+    )
+    german_query = "Was ist die Hauptstadt der Bundesrepublik Deutschland?"
+    translated_answers = wrapped.run(query=german_query, top_k_generator=2, top_k_retriever=1)["answers"]
+
+    assert len(translated_answers) == 2
+    assert "berlin" in translated_answers[0]["answer"]
diff --git a/test/test_pipeline.py b/test/test_pipeline.py
index 7cad05d9f..68abaf910 100644
--- a/test/test_pipeline.py
+++ b/test/test_pipeline.py
@@ -3,7 +3,8 @@ from pathlib import Path
import pytest
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
-from haystack.pipeline import JoinDocuments, ExtractiveQAPipeline, Pipeline, FAQPipeline, DocumentSearchPipeline
+from haystack.pipeline import TranslationWrapperPipeline, JoinDocuments, ExtractiveQAPipeline, Pipeline, FAQPipeline, \
+ DocumentSearchPipeline
from haystack.retriever.dense import DensePassageRetriever
from haystack.retriever.sparse import ElasticsearchRetriever
@@ -137,6 +138,27 @@ def test_document_search_pipeline(retriever, document_store):
assert len(output["documents"]) == 1
+@pytest.mark.slow
+@pytest.mark.elasticsearch
+@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
+def test_extractive_qa_answers_with_translator(reader, retriever_with_docs, en_to_de_translator, de_to_en_translator):
+    """Wrap an ExtractiveQAPipeline with de->en / en->de translators and inspect the top answer."""
+    german_query = "Wer lebt in Berlin?"
+    wrapped = TranslationWrapperPipeline(
+        input_translator=de_to_en_translator,
+        output_translator=en_to_de_translator,
+        pipeline=ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs),
+    )
+    prediction = wrapped.run(query=german_query, top_k_retriever=10, top_k_reader=3)
+    assert prediction is not None
+    assert prediction["query"] == german_query
+    top_answer = prediction["answers"][0]
+    assert "Carla" in top_answer["answer"]
+    assert 0 <= top_answer["probability"] <= 1
+    assert top_answer["meta"]["meta_field"] == "test1"
+    assert top_answer["context"] == "My name is Carla and I live in Berlin"
+
+
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
def test_join_document_pipeline(document_store_with_docs, reader):
diff --git a/test/test_summarizer.py b/test/test_summarizer.py
index 0f9a6527e..17747cc1a 100644
--- a/test/test_summarizer.py
+++ b/test/test_summarizer.py
@@ -1,7 +1,7 @@
import pytest
from haystack import Document
-from haystack.pipeline import SearchSummarizationPipeline
+from haystack.pipeline import TranslationWrapperPipeline, SearchSummarizationPipeline
from haystack.retriever.dense import DensePassageRetriever, EmbeddingRetriever
DOCS = [
@@ -94,3 +94,41 @@ def test_summarization_pipeline_one_summary(document_store, retriever, summarize
answers = output["answers"]
assert len(answers) == 1
assert answers[0]["answer"] in EXPECTED_ONE_SUMMARIES
+
+
+# Only two (retriever, document_store) combinations are exercised to keep the test time down
+@pytest.mark.slow
+@pytest.mark.elasticsearch
+@pytest.mark.summarizer
+@pytest.mark.parametrize(
+    "retriever,document_store",
+    [("embedding", "memory"), ("elasticsearch", "elasticsearch")],
+    indirect=True,
+)
+def test_summarization_pipeline_with_translator(
+    document_store,
+    retriever,
+    summarizer,
+    en_to_de_translator,
+    de_to_en_translator
+):
+    """Wrap a SearchSummarizationPipeline with de->en / en->de translators and check the single summary."""
+    accepted_summaries = (
+        "Der Eiffelturm ist ein Wahrzeichen in Paris, Frankreich.",
+        "Der Eiffelturm, der 1889 in Paris, Frankreich, erbaut wurde, ist das höchste freistehende Bauwerk der Welt.",
+    )
+    document_store.write_documents(SPLIT_DOCS)
+
+    if isinstance(retriever, (EmbeddingRetriever, DensePassageRetriever)):
+        document_store.update_embeddings(retriever=retriever)
+
+    wrapped = TranslationWrapperPipeline(
+        input_translator=de_to_en_translator,
+        output_translator=en_to_de_translator,
+        pipeline=SearchSummarizationPipeline(retriever=retriever, summarizer=summarizer),
+    )
+    output = wrapped.run(query="Wo steht der Eiffelturm?", top_k_retriever=2, generate_single_summary=True)
+    # SearchSummarizationPipeline returns answers, but the bare Summarizer node returns documents
+    summaries = output["documents"]
+    assert len(summaries) == 1
+    assert summaries[0].text in accepted_summaries
diff --git a/test/test_translator.py b/test/test_translator.py
new file mode 100644
index 000000000..803221f2b
--- /dev/null
+++ b/test/test_translator.py
@@ -0,0 +1,46 @@
+from haystack import Document
+
+import pytest
+
+EXPECTED_OUTPUT = "Ich lebe in Berlin"
+INPUT = "I live in Berlin"
+
+
+def test_translator_with_query(en_to_de_translator):  # a query string comes back as one translated string
+    assert en_to_de_translator.translate(query=INPUT) == EXPECTED_OUTPUT
+
+
+def test_translator_with_list(en_to_de_translator):  # a list of strings yields translated strings
+    assert en_to_de_translator.translate(documents=[INPUT])[0] == EXPECTED_OUTPUT
+
+
+def test_translator_with_document(en_to_de_translator):  # Document items carry the translation in `.text`
+    assert en_to_de_translator.translate(documents=[Document(text=INPUT)])[0].text == EXPECTED_OUTPUT
+
+
+def test_translator_with_dictionary(en_to_de_translator):  # dict items are read from the "text" key when no dict_key is given
+    assert en_to_de_translator.translate(documents=[{"text": INPUT}])[0]["text"] == EXPECTED_OUTPUT
+
+
+def test_translator_with_dictionary_with_dict_key(en_to_de_translator):  # an explicit dict_key selects the field
+    assert en_to_de_translator.translate(documents=[{"key": INPUT}], dict_key="key")[0]["key"] == EXPECTED_OUTPUT
+
+
+def test_translator_with_empty_input(en_to_de_translator):  # calling without query and documents must fail
+    with pytest.raises(AttributeError):
+        en_to_de_translator.translate()
+
+
+def test_translator_with_query_and_documents(en_to_de_translator):  # query and documents together must fail
+    with pytest.raises(AttributeError):
+        en_to_de_translator.translate(query=INPUT, documents=[INPUT])
+
+
+def test_translator_with_dict_without_text_key(en_to_de_translator):  # a dict lacking the "text" key must fail
+    with pytest.raises(AttributeError):
+        en_to_de_translator.translate(documents=[{"text1": INPUT}])
+
+
+def test_translator_with_dict_with_non_string_value(en_to_de_translator):  # a non-str value under "text" must fail
+    with pytest.raises(AttributeError):
+        en_to_de_translator.translate(documents=[{"text": 123}])