fix(translator): write translated text to output documents, while keeping input untouched (#3077)

* Set translated text on a copy of original document

* Return new translated list

* Manually generated docs

TODO: check pre-commit

* Hook generated file

* Rename variables for better maintenance

* fix(translator): prevent inputs from being changed

* fix: manual update translator docs

* style(translator): explicit type declaration on List

* docs(translator): re-run pre-commit hook

* style(translator): ignore mypy wrong type check

* docs(translator): re-run pre-commit hook
This commit is contained in:
Daniel Bichuetti 2022-08-22 05:07:05 -03:00 committed by GitHub
parent bc6f71b5ba
commit d5e36ce6b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 17 deletions

View File

@ -100,7 +100,7 @@ tokenizer.
#### TransformersTranslator.translate
```python
def translate(results: List[Dict[str, Any]] = None, query: Optional[str] = None, documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None, dict_key: Optional[str] = None) -> Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]
def translate(results: Optional[List[Dict[str, Any]]] = None, query: Optional[str] = None, documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None, dict_key: Optional[str] = None) -> Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]
```
Run the actual translation. You can supply a query or a list of documents. Whatever is supplied will be translated.

View File

@ -1,8 +1,9 @@
import logging
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer # type: ignore
from haystack.errors import HaystackError
from haystack.schema import Document, Answer
@ -77,7 +78,7 @@ class TransformersTranslator(BaseTranslator):
def translate(
self,
results: List[Dict[str, Any]] = None,
results: Optional[List[Dict[str, Any]]] = None,
query: Optional[str] = None,
documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None,
dict_key: Optional[str] = None,
@ -95,10 +96,10 @@ class TransformersTranslator(BaseTranslator):
queries_for_translator = [result["query"] for result in results]
answers_for_translator = [result["answers"][0].answer for result in results]
if not query and not documents and results is None:
raise AttributeError("Translator needs query or documents to perform translation.")
raise AttributeError("Translator needs a query or documents to perform translation.")
if query and documents:
raise AttributeError("Translator needs either query or documents but not both.")
raise AttributeError("Translator needs either a query or documents but not both.")
if documents and len(documents) == 0:
logger.warning("Empty documents list is passed")
@ -144,17 +145,22 @@ class TransformersTranslator(BaseTranslator):
if isinstance(documents, list) and isinstance(documents[0], str):
return [translated_text for translated_text in translated_texts]
translated_documents: Union[
List[Document], List[Answer], List[str], List[Dict[str, Any]]
] = [] # type: ignore
for translated_text, doc in zip(translated_texts, documents):
if isinstance(doc, Document):
doc.content = translated_text
elif isinstance(doc, Answer):
doc.answer = translated_text
translated_document = deepcopy(doc)
if isinstance(translated_document, Document):
translated_document.content = translated_text
elif isinstance(translated_document, Answer):
translated_document.answer = translated_text
else:
doc[dict_key] = translated_text # type: ignore
translated_document[dict_key] = translated_text # type: ignore
translated_documents.append(translated_document) # type: ignore
return documents
return translated_documents
raise AttributeError("Translator need query or documents to perform translation")
raise AttributeError("Translator needs a query or documents to perform translation")
def translate_batch(
self,
@ -175,11 +181,11 @@ class TransformersTranslator(BaseTranslator):
raise AttributeError("Translator needs either query or documents but not both.")
if not queries and not documents:
raise AttributeError("Translator needs query or documents to perform translation.")
raise AttributeError("Translator needs a query or documents to perform translation.")
translated = []
# Translate queries
if queries:
translated = []
for query in tqdm(queries, disable=not self.progress_bar, desc="Translating"):
cur_translation = self.translate(query=query)
translated.append(cur_translation)
@ -188,14 +194,13 @@ class TransformersTranslator(BaseTranslator):
elif documents:
# Single list of documents / answers
if not isinstance(documents[0], list):
translated = self.translate(documents=documents) # type: ignore
translated.append(self.translate(documents=documents)) # type: ignore
# Multiple lists of document / answer lists
else:
translated = []
for cur_list in tqdm(documents, disable=not self.progress_bar, desc="Translating"):
if not isinstance(cur_list, list):
raise HaystackError(
f"cur_list was of type {type(cur_list)}, but expected a list of " f"Documents / Answers."
f"cur_list was of type {type(cur_list)}, but expected a list of Documents / Answers."
)
cur_translation = self.translate(documents=cur_list)
translated.append(cur_translation)

View File

@ -5,6 +5,8 @@ import pytest
EXPECTED_OUTPUT = "Ich lebe in Berlin"
INPUT = "I live in Berlin"
DOCUMENT_INPUT = Document(content=INPUT)
def test_translator_with_query(en_to_de_translator):
assert en_to_de_translator.translate(query=INPUT) == EXPECTED_OUTPUT
@ -18,10 +20,22 @@ def test_translator_with_document(en_to_de_translator):
assert en_to_de_translator.translate(documents=[Document(content=INPUT)])[0].content == EXPECTED_OUTPUT
def test_translator_with_document_preserves_input(en_to_de_translator):
original_document = Document(content=INPUT)
en_to_de_translator.translate(documents=[original_document])[0] # pylint: disable=expression-not-assigned
assert original_document.content == INPUT
def test_translator_with_dictionary(en_to_de_translator):
assert en_to_de_translator.translate(documents=[{"content": INPUT}])[0]["content"] == EXPECTED_OUTPUT
def test_translator_with_dictionary_preserves_input(en_to_de_translator):
original_document = {"content": INPUT}
en_to_de_translator.translate(documents=[original_document])[0] # pylint: disable=expression-not-assigned
assert original_document["content"] == INPUT
def test_translator_with_dictionary_with_dict_key(en_to_de_translator):
assert en_to_de_translator.translate(documents=[{"key": INPUT}], dict_key="key")[0]["key"] == EXPECTED_OUTPUT