From 3a849d6c073068d15fe21cdb92eaaf35be47baea Mon Sep 17 00:00:00 2001 From: bogdankostic Date: Mon, 15 Aug 2022 10:05:34 +0200 Subject: [PATCH] bug: Make `TranslationWrapperPipeline` work with `QuestionAnswerGenerationPipeline` (#3034) * Overwrite output_translator's run method with run_batch * Fix mypy * Revert change * Overwrite run method only with QuestionAnswerGenerationPipeline --- docs/_src/api/api/translator.md | 2 +- haystack/nodes/translator/base.py | 2 +- haystack/nodes/translator/transformers.py | 4 ++-- haystack/pipelines/standard_pipelines.py | 5 +++++ 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/docs/_src/api/api/translator.md b/docs/_src/api/api/translator.md index ee934f5f4..159a6b987 100644 --- a/docs/_src/api/api/translator.md +++ b/docs/_src/api/api/translator.md @@ -117,7 +117,7 @@ Run the actual translation. You can supply a query or a list of documents. Whate #### TransformersTranslator.translate\_batch ```python -def translate_batch(queries: Optional[List[str]] = None, documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None, batch_size: Optional[int] = None) -> Union[str, List[str], List[Document], List[Answer], List[List[Document]], List[List[Answer]]] +def translate_batch(queries: Optional[List[str]] = None, documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None, batch_size: Optional[int] = None) -> List[Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]] ``` Run the actual translation. You can supply a single query, a list of queries or a list (of lists) of documents. diff --git a/haystack/nodes/translator/base.py b/haystack/nodes/translator/base.py index 5b3fa754b..41d0eecb8 100644 --- a/haystack/nodes/translator/base.py +++ b/haystack/nodes/translator/base.py @@ -32,7 +32,7 @@ class BaseTranslator(BaseComponent): queries: Optional[List[str]] = None, documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None, batch_size: Optional[int] = None, - ) -> Union[str, List[str], List[Document], List[Answer], List[List[Document]], List[List[Answer]]]: + ) -> List[Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]]: pass def run( # type: ignore diff --git a/haystack/nodes/translator/transformers.py b/haystack/nodes/translator/transformers.py index b125baf70..217599496 100644 --- a/haystack/nodes/translator/transformers.py +++ b/haystack/nodes/translator/transformers.py @@ -161,7 +161,7 @@ class TransformersTranslator(BaseTranslator): queries: Optional[List[str]] = None, documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None, batch_size: Optional[int] = None, - ) -> Union[str, List[str], List[Document], List[Answer], List[List[Document]], List[List[Answer]]]: + ) -> List[Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]]: """ Run the actual translation. You can supply a single query, a list of queries or a list (of lists) of documents. @@ -181,7 +181,7 @@ class TransformersTranslator(BaseTranslator): if queries: translated = [] for query in tqdm(queries, disable=not self.progress_bar, desc="Translating"): - cur_translation = self.run(query=query) + cur_translation = self.translate(query=query) translated.append(cur_translation) # Translate docs / answers diff --git a/haystack/pipelines/standard_pipelines.py b/haystack/pipelines/standard_pipelines.py index 0fcdb1839..4c177a7dd 100644 --- a/haystack/pipelines/standard_pipelines.py +++ b/haystack/pipelines/standard_pipelines.py @@ -523,6 +523,11 @@ class TranslationWrapperPipeline(BaseStandardPipeline): self.pipeline = Pipeline() self.pipeline.add_node(component=input_translator, name="InputTranslator", inputs=["Query"]) + # Make use of run_batch instead of run for output_translator if pipeline is a QuestionAnswerGenerationPipeline, + # as the reader's run method is overwritten by its run_batch method, which is incompatible with the translator's + # run method. + if isinstance(pipeline, QuestionAnswerGenerationPipeline): + setattr(output_translator, "run", output_translator.run_batch) graph = pipeline.pipeline.graph previous_node_name = ["InputTranslator"]