mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-15 18:08:40 +00:00
fix(translator): write translated text to output documents, while keeping input untouched (#3077)
* Set translated text on a copy of original document
* Return new translated list
* Manually generated docs TODO: check pre-commit
* Hook generated file
* Rename variables for better maintenance
* fix(translator): prevent inputs from being changed
* fix: manual update translator docs
* style(translator): explicit type declaration on List
* docs(translator): re-run pre-commit hook
* style(translator): ignore mypy wrong type check
* docs(translator): re-run pre-commit hook
This commit is contained in:
parent
bc6f71b5ba
commit
d5e36ce6b4
@ -100,7 +100,7 @@ tokenizer.
|
|||||||
#### TransformersTranslator.translate
|
#### TransformersTranslator.translate
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def translate(results: List[Dict[str, Any]] = None, query: Optional[str] = None, documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None, dict_key: Optional[str] = None) -> Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]
|
def translate(results: Optional[List[Dict[str, Any]]] = None, query: Optional[str] = None, documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None, dict_key: Optional[str] = None) -> Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]
|
||||||
```
|
```
|
||||||
|
|
||||||
Run the actual translation. You can supply a query or a list of documents. Whatever is supplied will be translated.
|
Run the actual translation. You can supply a query or a list of documents. Whatever is supplied will be translated.
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from copy import deepcopy
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer # type: ignore
|
||||||
|
|
||||||
from haystack.errors import HaystackError
|
from haystack.errors import HaystackError
|
||||||
from haystack.schema import Document, Answer
|
from haystack.schema import Document, Answer
|
||||||
@ -77,7 +78,7 @@ class TransformersTranslator(BaseTranslator):
|
|||||||
|
|
||||||
def translate(
|
def translate(
|
||||||
self,
|
self,
|
||||||
results: List[Dict[str, Any]] = None,
|
results: Optional[List[Dict[str, Any]]] = None,
|
||||||
query: Optional[str] = None,
|
query: Optional[str] = None,
|
||||||
documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None,
|
documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None,
|
||||||
dict_key: Optional[str] = None,
|
dict_key: Optional[str] = None,
|
||||||
@ -95,10 +96,10 @@ class TransformersTranslator(BaseTranslator):
|
|||||||
queries_for_translator = [result["query"] for result in results]
|
queries_for_translator = [result["query"] for result in results]
|
||||||
answers_for_translator = [result["answers"][0].answer for result in results]
|
answers_for_translator = [result["answers"][0].answer for result in results]
|
||||||
if not query and not documents and results is None:
|
if not query and not documents and results is None:
|
||||||
raise AttributeError("Translator needs query or documents to perform translation.")
|
raise AttributeError("Translator needs a query or documents to perform translation.")
|
||||||
|
|
||||||
if query and documents:
|
if query and documents:
|
||||||
raise AttributeError("Translator needs either query or documents but not both.")
|
raise AttributeError("Translator needs either a query or documents but not both.")
|
||||||
|
|
||||||
if documents and len(documents) == 0:
|
if documents and len(documents) == 0:
|
||||||
logger.warning("Empty documents list is passed")
|
logger.warning("Empty documents list is passed")
|
||||||
@ -144,17 +145,22 @@ class TransformersTranslator(BaseTranslator):
|
|||||||
if isinstance(documents, list) and isinstance(documents[0], str):
|
if isinstance(documents, list) and isinstance(documents[0], str):
|
||||||
return [translated_text for translated_text in translated_texts]
|
return [translated_text for translated_text in translated_texts]
|
||||||
|
|
||||||
|
translated_documents: Union[
|
||||||
|
List[Document], List[Answer], List[str], List[Dict[str, Any]]
|
||||||
|
] = [] # type: ignore
|
||||||
for translated_text, doc in zip(translated_texts, documents):
|
for translated_text, doc in zip(translated_texts, documents):
|
||||||
if isinstance(doc, Document):
|
translated_document = deepcopy(doc)
|
||||||
doc.content = translated_text
|
if isinstance(translated_document, Document):
|
||||||
elif isinstance(doc, Answer):
|
translated_document.content = translated_text
|
||||||
doc.answer = translated_text
|
elif isinstance(translated_document, Answer):
|
||||||
|
translated_document.answer = translated_text
|
||||||
else:
|
else:
|
||||||
doc[dict_key] = translated_text # type: ignore
|
translated_document[dict_key] = translated_text # type: ignore
|
||||||
|
translated_documents.append(translated_document) # type: ignore
|
||||||
|
|
||||||
return documents
|
return translated_documents
|
||||||
|
|
||||||
raise AttributeError("Translator need query or documents to perform translation")
|
raise AttributeError("Translator needs a query or documents to perform translation")
|
||||||
|
|
||||||
def translate_batch(
|
def translate_batch(
|
||||||
self,
|
self,
|
||||||
@ -175,11 +181,11 @@ class TransformersTranslator(BaseTranslator):
|
|||||||
raise AttributeError("Translator needs either query or documents but not both.")
|
raise AttributeError("Translator needs either query or documents but not both.")
|
||||||
|
|
||||||
if not queries and not documents:
|
if not queries and not documents:
|
||||||
raise AttributeError("Translator needs query or documents to perform translation.")
|
raise AttributeError("Translator needs a query or documents to perform translation.")
|
||||||
|
|
||||||
|
translated = []
|
||||||
# Translate queries
|
# Translate queries
|
||||||
if queries:
|
if queries:
|
||||||
translated = []
|
|
||||||
for query in tqdm(queries, disable=not self.progress_bar, desc="Translating"):
|
for query in tqdm(queries, disable=not self.progress_bar, desc="Translating"):
|
||||||
cur_translation = self.translate(query=query)
|
cur_translation = self.translate(query=query)
|
||||||
translated.append(cur_translation)
|
translated.append(cur_translation)
|
||||||
@ -188,14 +194,13 @@ class TransformersTranslator(BaseTranslator):
|
|||||||
elif documents:
|
elif documents:
|
||||||
# Single list of documents / answers
|
# Single list of documents / answers
|
||||||
if not isinstance(documents[0], list):
|
if not isinstance(documents[0], list):
|
||||||
translated = self.translate(documents=documents) # type: ignore
|
translated.append(self.translate(documents=documents)) # type: ignore
|
||||||
# Multiple lists of document / answer lists
|
# Multiple lists of document / answer lists
|
||||||
else:
|
else:
|
||||||
translated = []
|
|
||||||
for cur_list in tqdm(documents, disable=not self.progress_bar, desc="Translating"):
|
for cur_list in tqdm(documents, disable=not self.progress_bar, desc="Translating"):
|
||||||
if not isinstance(cur_list, list):
|
if not isinstance(cur_list, list):
|
||||||
raise HaystackError(
|
raise HaystackError(
|
||||||
f"cur_list was of type {type(cur_list)}, but expected a list of " f"Documents / Answers."
|
f"cur_list was of type {type(cur_list)}, but expected a list of Documents / Answers."
|
||||||
)
|
)
|
||||||
cur_translation = self.translate(documents=cur_list)
|
cur_translation = self.translate(documents=cur_list)
|
||||||
translated.append(cur_translation)
|
translated.append(cur_translation)
|
||||||
|
@ -5,6 +5,8 @@ import pytest
|
|||||||
EXPECTED_OUTPUT = "Ich lebe in Berlin"
|
EXPECTED_OUTPUT = "Ich lebe in Berlin"
|
||||||
INPUT = "I live in Berlin"
|
INPUT = "I live in Berlin"
|
||||||
|
|
||||||
|
DOCUMENT_INPUT = Document(content=INPUT)
|
||||||
|
|
||||||
|
|
||||||
def test_translator_with_query(en_to_de_translator):
|
def test_translator_with_query(en_to_de_translator):
|
||||||
assert en_to_de_translator.translate(query=INPUT) == EXPECTED_OUTPUT
|
assert en_to_de_translator.translate(query=INPUT) == EXPECTED_OUTPUT
|
||||||
@ -18,10 +20,22 @@ def test_translator_with_document(en_to_de_translator):
|
|||||||
assert en_to_de_translator.translate(documents=[Document(content=INPUT)])[0].content == EXPECTED_OUTPUT
|
assert en_to_de_translator.translate(documents=[Document(content=INPUT)])[0].content == EXPECTED_OUTPUT
|
||||||
|
|
||||||
|
|
||||||
|
def test_translator_with_document_preserves_input(en_to_de_translator):
|
||||||
|
original_document = Document(content=INPUT)
|
||||||
|
en_to_de_translator.translate(documents=[original_document])[0] # pylint: disable=expression-not-assigned
|
||||||
|
assert original_document.content == INPUT
|
||||||
|
|
||||||
|
|
||||||
def test_translator_with_dictionary(en_to_de_translator):
|
def test_translator_with_dictionary(en_to_de_translator):
|
||||||
assert en_to_de_translator.translate(documents=[{"content": INPUT}])[0]["content"] == EXPECTED_OUTPUT
|
assert en_to_de_translator.translate(documents=[{"content": INPUT}])[0]["content"] == EXPECTED_OUTPUT
|
||||||
|
|
||||||
|
|
||||||
|
def test_translator_with_dictionary_preserves_input(en_to_de_translator):
|
||||||
|
original_document = {"content": INPUT}
|
||||||
|
en_to_de_translator.translate(documents=[original_document])[0] # pylint: disable=expression-not-assigned
|
||||||
|
assert original_document["content"] == INPUT
|
||||||
|
|
||||||
|
|
||||||
def test_translator_with_dictionary_with_dict_key(en_to_de_translator):
|
def test_translator_with_dictionary_with_dict_key(en_to_de_translator):
|
||||||
assert en_to_de_translator.translate(documents=[{"key": INPUT}], dict_key="key")[0]["key"] == EXPECTED_OUTPUT
|
assert en_to_de_translator.translate(documents=[{"key": INPUT}], dict_key="key")[0]["key"] == EXPECTED_OUTPUT
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user