fix(translator): write translated text to output documents, while keeping input untouched (#3077)

* Set translated text on a copy of the original document

* Return new translated list

* Manually generated docs

TODO: check pre-commit

* Hook generated file

* Rename variables for better maintenance

* fix(translator): prevent inputs from being changed

* fix: manual update translator docs

* style(translator): explicit type declaration on List

* docs(translator): re-run pre-commit hook

* style(translator): ignore mypy's incorrect type check

* docs(translator): re-run pre-commit hook
This commit is contained in:
Daniel Bichuetti 2022-08-22 05:07:05 -03:00 committed by GitHub
parent bc6f71b5ba
commit d5e36ce6b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 17 deletions

View File

@ -100,7 +100,7 @@ tokenizer.
#### TransformersTranslator.translate #### TransformersTranslator.translate
```python ```python
def translate(results: List[Dict[str, Any]] = None, query: Optional[str] = None, documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None, dict_key: Optional[str] = None) -> Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]] def translate(results: Optional[List[Dict[str, Any]]] = None, query: Optional[str] = None, documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None, dict_key: Optional[str] = None) -> Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]
``` ```
Run the actual translation. You can supply a query or a list of documents. Whatever is supplied will be translated. Run the actual translation. You can supply a query or a list of documents. Whatever is supplied will be translated.

View File

@ -1,8 +1,9 @@
import logging import logging
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer from transformers import AutoModelForSeq2SeqLM, AutoTokenizer # type: ignore
from haystack.errors import HaystackError from haystack.errors import HaystackError
from haystack.schema import Document, Answer from haystack.schema import Document, Answer
@ -77,7 +78,7 @@ class TransformersTranslator(BaseTranslator):
def translate( def translate(
self, self,
results: List[Dict[str, Any]] = None, results: Optional[List[Dict[str, Any]]] = None,
query: Optional[str] = None, query: Optional[str] = None,
documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None, documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None,
dict_key: Optional[str] = None, dict_key: Optional[str] = None,
@ -95,10 +96,10 @@ class TransformersTranslator(BaseTranslator):
queries_for_translator = [result["query"] for result in results] queries_for_translator = [result["query"] for result in results]
answers_for_translator = [result["answers"][0].answer for result in results] answers_for_translator = [result["answers"][0].answer for result in results]
if not query and not documents and results is None: if not query and not documents and results is None:
raise AttributeError("Translator needs query or documents to perform translation.") raise AttributeError("Translator needs a query or documents to perform translation.")
if query and documents: if query and documents:
raise AttributeError("Translator needs either query or documents but not both.") raise AttributeError("Translator needs either a query or documents but not both.")
if documents and len(documents) == 0: if documents and len(documents) == 0:
logger.warning("Empty documents list is passed") logger.warning("Empty documents list is passed")
@ -144,17 +145,22 @@ class TransformersTranslator(BaseTranslator):
if isinstance(documents, list) and isinstance(documents[0], str): if isinstance(documents, list) and isinstance(documents[0], str):
return [translated_text for translated_text in translated_texts] return [translated_text for translated_text in translated_texts]
translated_documents: Union[
List[Document], List[Answer], List[str], List[Dict[str, Any]]
] = [] # type: ignore
for translated_text, doc in zip(translated_texts, documents): for translated_text, doc in zip(translated_texts, documents):
if isinstance(doc, Document): translated_document = deepcopy(doc)
doc.content = translated_text if isinstance(translated_document, Document):
elif isinstance(doc, Answer): translated_document.content = translated_text
doc.answer = translated_text elif isinstance(translated_document, Answer):
translated_document.answer = translated_text
else: else:
doc[dict_key] = translated_text # type: ignore translated_document[dict_key] = translated_text # type: ignore
translated_documents.append(translated_document) # type: ignore
return documents return translated_documents
raise AttributeError("Translator need query or documents to perform translation") raise AttributeError("Translator needs a query or documents to perform translation")
def translate_batch( def translate_batch(
self, self,
@ -175,11 +181,11 @@ class TransformersTranslator(BaseTranslator):
raise AttributeError("Translator needs either query or documents but not both.") raise AttributeError("Translator needs either query or documents but not both.")
if not queries and not documents: if not queries and not documents:
raise AttributeError("Translator needs query or documents to perform translation.") raise AttributeError("Translator needs a query or documents to perform translation.")
translated = []
# Translate queries # Translate queries
if queries: if queries:
translated = []
for query in tqdm(queries, disable=not self.progress_bar, desc="Translating"): for query in tqdm(queries, disable=not self.progress_bar, desc="Translating"):
cur_translation = self.translate(query=query) cur_translation = self.translate(query=query)
translated.append(cur_translation) translated.append(cur_translation)
@ -188,14 +194,13 @@ class TransformersTranslator(BaseTranslator):
elif documents: elif documents:
# Single list of documents / answers # Single list of documents / answers
if not isinstance(documents[0], list): if not isinstance(documents[0], list):
translated = self.translate(documents=documents) # type: ignore translated.append(self.translate(documents=documents)) # type: ignore
# Multiple lists of document / answer lists # Multiple lists of document / answer lists
else: else:
translated = []
for cur_list in tqdm(documents, disable=not self.progress_bar, desc="Translating"): for cur_list in tqdm(documents, disable=not self.progress_bar, desc="Translating"):
if not isinstance(cur_list, list): if not isinstance(cur_list, list):
raise HaystackError( raise HaystackError(
f"cur_list was of type {type(cur_list)}, but expected a list of " f"Documents / Answers." f"cur_list was of type {type(cur_list)}, but expected a list of Documents / Answers."
) )
cur_translation = self.translate(documents=cur_list) cur_translation = self.translate(documents=cur_list)
translated.append(cur_translation) translated.append(cur_translation)

View File

@ -5,6 +5,8 @@ import pytest
EXPECTED_OUTPUT = "Ich lebe in Berlin" EXPECTED_OUTPUT = "Ich lebe in Berlin"
INPUT = "I live in Berlin" INPUT = "I live in Berlin"
DOCUMENT_INPUT = Document(content=INPUT)
def test_translator_with_query(en_to_de_translator): def test_translator_with_query(en_to_de_translator):
assert en_to_de_translator.translate(query=INPUT) == EXPECTED_OUTPUT assert en_to_de_translator.translate(query=INPUT) == EXPECTED_OUTPUT
@ -18,10 +20,22 @@ def test_translator_with_document(en_to_de_translator):
assert en_to_de_translator.translate(documents=[Document(content=INPUT)])[0].content == EXPECTED_OUTPUT assert en_to_de_translator.translate(documents=[Document(content=INPUT)])[0].content == EXPECTED_OUTPUT
def test_translator_with_document_preserves_input(en_to_de_translator):
original_document = Document(content=INPUT)
en_to_de_translator.translate(documents=[original_document])[0] # pylint: disable=expression-not-assigned
assert original_document.content == INPUT
def test_translator_with_dictionary(en_to_de_translator): def test_translator_with_dictionary(en_to_de_translator):
assert en_to_de_translator.translate(documents=[{"content": INPUT}])[0]["content"] == EXPECTED_OUTPUT assert en_to_de_translator.translate(documents=[{"content": INPUT}])[0]["content"] == EXPECTED_OUTPUT
def test_translator_with_dictionary_preserves_input(en_to_de_translator):
original_document = {"content": INPUT}
en_to_de_translator.translate(documents=[original_document])[0] # pylint: disable=expression-not-assigned
assert original_document["content"] == INPUT
def test_translator_with_dictionary_with_dict_key(en_to_de_translator): def test_translator_with_dictionary_with_dict_key(en_to_de_translator):
assert en_to_de_translator.translate(documents=[{"key": INPUT}], dict_key="key")[0]["key"] == EXPECTED_OUTPUT assert en_to_de_translator.translate(documents=[{"key": INPUT}], dict_key="key")[0]["key"] == EXPECTED_OUTPUT