diff --git a/docs/_src/api/api/translator.md b/docs/_src/api/api/translator.md index 159a6b987..f996b3c5a 100644 --- a/docs/_src/api/api/translator.md +++ b/docs/_src/api/api/translator.md @@ -100,7 +100,7 @@ tokenizer. #### TransformersTranslator.translate ```python -def translate(results: List[Dict[str, Any]] = None, query: Optional[str] = None, documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None, dict_key: Optional[str] = None) -> Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]] +def translate(results: Optional[List[Dict[str, Any]]] = None, query: Optional[str] = None, documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None, dict_key: Optional[str] = None) -> Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]] ``` Run the actual translation. You can supply a query or a list of documents. Whatever is supplied will be translated. diff --git a/haystack/nodes/translator/transformers.py b/haystack/nodes/translator/transformers.py index 217599496..69153debf 100644 --- a/haystack/nodes/translator/transformers.py +++ b/haystack/nodes/translator/transformers.py @@ -1,8 +1,9 @@ import logging +from copy import deepcopy from typing import Any, Dict, List, Optional, Union from tqdm.auto import tqdm -from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer # type: ignore from haystack.errors import HaystackError from haystack.schema import Document, Answer @@ -77,7 +78,7 @@ class TransformersTranslator(BaseTranslator): def translate( self, - results: List[Dict[str, Any]] = None, + results: Optional[List[Dict[str, Any]]] = None, query: Optional[str] = None, documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None, dict_key: Optional[str] = None, @@ -95,10 +96,10 @@ class TransformersTranslator(BaseTranslator): queries_for_translator = [result["query"] for result in results] answers_for_translator = [result["answers"][0].answer for result in results] if not query and not documents and results is None: - raise AttributeError("Translator needs query or documents to perform translation.") + raise AttributeError("Translator needs a query or documents to perform translation.") if query and documents: - raise AttributeError("Translator needs either query or documents but not both.") + raise AttributeError("Translator needs either a query or documents but not both.") if documents and len(documents) == 0: logger.warning("Empty documents list is passed") @@ -144,17 +145,22 @@ class TransformersTranslator(BaseTranslator): if isinstance(documents, list) and isinstance(documents[0], str): return [translated_text for translated_text in translated_texts] + translated_documents: Union[ + List[Document], List[Answer], List[str], List[Dict[str, Any]] + ] = [] # type: ignore for translated_text, doc in zip(translated_texts, documents): - if isinstance(doc, Document): - doc.content = translated_text - elif isinstance(doc, Answer): - doc.answer = translated_text + translated_document = deepcopy(doc) + if isinstance(translated_document, Document): + translated_document.content = translated_text + elif isinstance(translated_document, Answer): + translated_document.answer = translated_text else: - doc[dict_key] = translated_text # type: ignore + translated_document[dict_key] = translated_text # type: ignore + translated_documents.append(translated_document) # type: ignore - return documents + return translated_documents - raise AttributeError("Translator need query or documents to perform translation") + raise AttributeError("Translator needs a query or documents to perform translation") def translate_batch( self, @@ -175,11 +181,11 @@ class TransformersTranslator(BaseTranslator): raise AttributeError("Translator needs either query or documents but not both.") if not queries and not documents: - raise AttributeError("Translator needs query or documents to perform translation.") + raise AttributeError("Translator needs a query or documents to perform translation.") + translated = [] # Translate queries if queries: - translated = [] for query in tqdm(queries, disable=not self.progress_bar, desc="Translating"): cur_translation = self.translate(query=query) translated.append(cur_translation) @@ -188,14 +194,13 @@ class TransformersTranslator(BaseTranslator): elif documents: # Single list of documents / answers if not isinstance(documents[0], list): - translated = self.translate(documents=documents) # type: ignore + translated.append(self.translate(documents=documents)) # type: ignore # Multiple lists of document / answer lists else: - translated = [] for cur_list in tqdm(documents, disable=not self.progress_bar, desc="Translating"): if not isinstance(cur_list, list): raise HaystackError( - f"cur_list was of type {type(cur_list)}, but expected a list of " f"Documents / Answers." + f"cur_list was of type {type(cur_list)}, but expected a list of Documents / Answers." ) cur_translation = self.translate(documents=cur_list) translated.append(cur_translation) diff --git a/test/nodes/test_translator.py b/test/nodes/test_translator.py index cba87a2ce..e1f2f478a 100644 --- a/test/nodes/test_translator.py +++ b/test/nodes/test_translator.py @@ -5,6 +5,8 @@ import pytest EXPECTED_OUTPUT = "Ich lebe in Berlin" INPUT = "I live in Berlin" +DOCUMENT_INPUT = Document(content=INPUT) + def test_translator_with_query(en_to_de_translator): assert en_to_de_translator.translate(query=INPUT) == EXPECTED_OUTPUT @@ -18,10 +20,22 @@ def test_translator_with_document(en_to_de_translator): assert en_to_de_translator.translate(documents=[Document(content=INPUT)])[0].content == EXPECTED_OUTPUT +def test_translator_with_document_preserves_input(en_to_de_translator): + original_document = Document(content=INPUT) + en_to_de_translator.translate(documents=[original_document])[0] # pylint: disable=expression-not-assigned + assert original_document.content == INPUT + + def test_translator_with_dictionary(en_to_de_translator): assert en_to_de_translator.translate(documents=[{"content": INPUT}])[0]["content"] == EXPECTED_OUTPUT +def test_translator_with_dictionary_preserves_input(en_to_de_translator): + original_document = {"content": INPUT} + en_to_de_translator.translate(documents=[original_document])[0] # pylint: disable=expression-not-assigned + assert original_document["content"] == INPUT + + def test_translator_with_dictionary_with_dict_key(en_to_de_translator): assert en_to_de_translator.translate(documents=[{"key": INPUT}], dict_key="key")[0]["key"] == EXPECTED_OUTPUT