mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-14 09:28:56 +00:00
fix(translator): write translated text to output documents, while keeping input untouched (#3077)
* Set translated text on a copy of original document * Return new translated list * Manually generated docs TODO: check pre-commit * Hook generated file * Rename variables for better maintenance * fix(translator): prevent inputs from being changed * fix: manual update translator docs * style(translator): explicit type declaration on List * docs(translator): re-run pre-commit hook * style(translator): ignore mypy wrong type check * docs(translator): re-run pre-commit hook
This commit is contained in:
parent
bc6f71b5ba
commit
d5e36ce6b4
@ -100,7 +100,7 @@ tokenizer.
|
||||
#### TransformersTranslator.translate
|
||||
|
||||
```python
|
||||
def translate(results: List[Dict[str, Any]] = None, query: Optional[str] = None, documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None, dict_key: Optional[str] = None) -> Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]
|
||||
def translate(results: Optional[List[Dict[str, Any]]] = None, query: Optional[str] = None, documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None, dict_key: Optional[str] = None) -> Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]
|
||||
```
|
||||
|
||||
Run the actual translation. You can supply a query or a list of documents. Whatever is supplied will be translated.
|
||||
|
@ -1,8 +1,9 @@
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer # type: ignore
|
||||
|
||||
from haystack.errors import HaystackError
|
||||
from haystack.schema import Document, Answer
|
||||
@ -77,7 +78,7 @@ class TransformersTranslator(BaseTranslator):
|
||||
|
||||
def translate(
|
||||
self,
|
||||
results: List[Dict[str, Any]] = None,
|
||||
results: Optional[List[Dict[str, Any]]] = None,
|
||||
query: Optional[str] = None,
|
||||
documents: Optional[Union[List[Document], List[Answer], List[str], List[Dict[str, Any]]]] = None,
|
||||
dict_key: Optional[str] = None,
|
||||
@ -95,10 +96,10 @@ class TransformersTranslator(BaseTranslator):
|
||||
queries_for_translator = [result["query"] for result in results]
|
||||
answers_for_translator = [result["answers"][0].answer for result in results]
|
||||
if not query and not documents and results is None:
|
||||
raise AttributeError("Translator needs query or documents to perform translation.")
|
||||
raise AttributeError("Translator needs a query or documents to perform translation.")
|
||||
|
||||
if query and documents:
|
||||
raise AttributeError("Translator needs either query or documents but not both.")
|
||||
raise AttributeError("Translator needs either a query or documents but not both.")
|
||||
|
||||
if documents and len(documents) == 0:
|
||||
logger.warning("Empty documents list is passed")
|
||||
@ -144,17 +145,22 @@ class TransformersTranslator(BaseTranslator):
|
||||
if isinstance(documents, list) and isinstance(documents[0], str):
|
||||
return [translated_text for translated_text in translated_texts]
|
||||
|
||||
translated_documents: Union[
|
||||
List[Document], List[Answer], List[str], List[Dict[str, Any]]
|
||||
] = [] # type: ignore
|
||||
for translated_text, doc in zip(translated_texts, documents):
|
||||
if isinstance(doc, Document):
|
||||
doc.content = translated_text
|
||||
elif isinstance(doc, Answer):
|
||||
doc.answer = translated_text
|
||||
translated_document = deepcopy(doc)
|
||||
if isinstance(translated_document, Document):
|
||||
translated_document.content = translated_text
|
||||
elif isinstance(translated_document, Answer):
|
||||
translated_document.answer = translated_text
|
||||
else:
|
||||
doc[dict_key] = translated_text # type: ignore
|
||||
translated_document[dict_key] = translated_text # type: ignore
|
||||
translated_documents.append(translated_document) # type: ignore
|
||||
|
||||
return documents
|
||||
return translated_documents
|
||||
|
||||
raise AttributeError("Translator need query or documents to perform translation")
|
||||
raise AttributeError("Translator needs a query or documents to perform translation")
|
||||
|
||||
def translate_batch(
|
||||
self,
|
||||
@ -175,11 +181,11 @@ class TransformersTranslator(BaseTranslator):
|
||||
raise AttributeError("Translator needs either query or documents but not both.")
|
||||
|
||||
if not queries and not documents:
|
||||
raise AttributeError("Translator needs query or documents to perform translation.")
|
||||
raise AttributeError("Translator needs a query or documents to perform translation.")
|
||||
|
||||
translated = []
|
||||
# Translate queries
|
||||
if queries:
|
||||
translated = []
|
||||
for query in tqdm(queries, disable=not self.progress_bar, desc="Translating"):
|
||||
cur_translation = self.translate(query=query)
|
||||
translated.append(cur_translation)
|
||||
@ -188,14 +194,13 @@ class TransformersTranslator(BaseTranslator):
|
||||
elif documents:
|
||||
# Single list of documents / answers
|
||||
if not isinstance(documents[0], list):
|
||||
translated = self.translate(documents=documents) # type: ignore
|
||||
translated.append(self.translate(documents=documents)) # type: ignore
|
||||
# Multiple lists of document / answer lists
|
||||
else:
|
||||
translated = []
|
||||
for cur_list in tqdm(documents, disable=not self.progress_bar, desc="Translating"):
|
||||
if not isinstance(cur_list, list):
|
||||
raise HaystackError(
|
||||
f"cur_list was of type {type(cur_list)}, but expected a list of " f"Documents / Answers."
|
||||
f"cur_list was of type {type(cur_list)}, but expected a list of Documents / Answers."
|
||||
)
|
||||
cur_translation = self.translate(documents=cur_list)
|
||||
translated.append(cur_translation)
|
||||
|
@ -5,6 +5,8 @@ import pytest
|
||||
EXPECTED_OUTPUT = "Ich lebe in Berlin"
|
||||
INPUT = "I live in Berlin"
|
||||
|
||||
DOCUMENT_INPUT = Document(content=INPUT)
|
||||
|
||||
|
||||
def test_translator_with_query(en_to_de_translator):
|
||||
assert en_to_de_translator.translate(query=INPUT) == EXPECTED_OUTPUT
|
||||
@ -18,10 +20,22 @@ def test_translator_with_document(en_to_de_translator):
|
||||
assert en_to_de_translator.translate(documents=[Document(content=INPUT)])[0].content == EXPECTED_OUTPUT
|
||||
|
||||
|
||||
def test_translator_with_document_preserves_input(en_to_de_translator):
|
||||
original_document = Document(content=INPUT)
|
||||
en_to_de_translator.translate(documents=[original_document])[0] # pylint: disable=expression-not-assigned
|
||||
assert original_document.content == INPUT
|
||||
|
||||
|
||||
def test_translator_with_dictionary(en_to_de_translator):
|
||||
assert en_to_de_translator.translate(documents=[{"content": INPUT}])[0]["content"] == EXPECTED_OUTPUT
|
||||
|
||||
|
||||
def test_translator_with_dictionary_preserves_input(en_to_de_translator):
|
||||
original_document = {"content": INPUT}
|
||||
en_to_de_translator.translate(documents=[original_document])[0] # pylint: disable=expression-not-assigned
|
||||
assert original_document["content"] == INPUT
|
||||
|
||||
|
||||
def test_translator_with_dictionary_with_dict_key(en_to_de_translator):
|
||||
assert en_to_de_translator.translate(documents=[{"key": INPUT}], dict_key="key")[0]["key"] == EXPECTED_OUTPUT
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user