bug: Make TranslationWrapperPipeline work with QuestionAnswerGenerationPipeline (#3034)

* Overwrite output_translator's run method with run_batch

* Fix mypy

* Revert change

* Overwrite run method only with QuestionAnswerGenerationPipeline
This commit is contained in:
bogdankostic 2022-08-15 10:05:34 +02:00 committed by GitHub
parent 1b422ab657
commit 3a849d6c07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 9 additions and 4 deletions

View File

@ -117,7 +117,7 @@ Run the actual translation. You can supply a query or a list of documents. Whate
#### TransformersTranslator.translate\_batch
```python
def translate_batch(queries: Optional[List[str]] = None, documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None, batch_size: Optional[int] = None) -> Union[str, List[str], List[Document], List[Answer], List[List[Document]], List[List[Answer]]]
def translate_batch(queries: Optional[List[str]] = None, documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None, batch_size: Optional[int] = None) -> List[Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]]
```
Run the actual translation. You can supply a single query, a list of queries or a list (of lists) of documents.

View File

@ -32,7 +32,7 @@ class BaseTranslator(BaseComponent):
queries: Optional[List[str]] = None,
documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None,
batch_size: Optional[int] = None,
) -> Union[str, List[str], List[Document], List[Answer], List[List[Document]], List[List[Answer]]]:
) -> List[Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]]:
pass
def run( # type: ignore

View File

@ -161,7 +161,7 @@ class TransformersTranslator(BaseTranslator):
queries: Optional[List[str]] = None,
documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None,
batch_size: Optional[int] = None,
) -> Union[str, List[str], List[Document], List[Answer], List[List[Document]], List[List[Answer]]]:
) -> List[Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]]:
"""
Run the actual translation. You can supply a single query, a list of queries or a list (of lists) of documents.
@ -181,7 +181,7 @@ class TransformersTranslator(BaseTranslator):
if queries:
translated = []
for query in tqdm(queries, disable=not self.progress_bar, desc="Translating"):
cur_translation = self.run(query=query)
cur_translation = self.translate(query=query)
translated.append(cur_translation)
# Translate docs / answers

View File

@ -523,6 +523,11 @@ class TranslationWrapperPipeline(BaseStandardPipeline):
self.pipeline = Pipeline()
self.pipeline.add_node(component=input_translator, name="InputTranslator", inputs=["Query"])
# Make use of run_batch instead of run for output_translator if pipeline is a QuestionAnswerGenerationPipeline,
# as the reader's run method is overwritten by its run_batch method, which is incompatible with the translator's
# run method.
if isinstance(pipeline, QuestionAnswerGenerationPipeline):
setattr(output_translator, "run", output_translator.run_batch)
graph = pipeline.pipeline.graph
previous_node_name = ["InputTranslator"]