mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-23 22:09:00 +00:00
bug: Make TranslationWrapperPipeline
work with QuestionAnswerGenerationPipeline
(#3034)
* Overwrite output_translator's run method with run_batch * Fix mypy * Revert change * Overwrite run method only with QuestionAnswerGenerationPipeline
This commit is contained in:
parent
1b422ab657
commit
3a849d6c07
@ -117,7 +117,7 @@ Run the actual translation. You can supply a query or a list of documents. Whate
|
||||
#### TransformersTranslator.translate\_batch
|
||||
|
||||
```python
|
||||
def translate_batch(queries: Optional[List[str]] = None, documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None, batch_size: Optional[int] = None) -> Union[str, List[str], List[Document], List[Answer], List[List[Document]], List[List[Answer]]]
|
||||
def translate_batch(queries: Optional[List[str]] = None, documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None, batch_size: Optional[int] = None) -> List[Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]]
|
||||
```
|
||||
|
||||
Run the actual translation. You can supply a single query, a list of queries or a list (of lists) of documents.
|
||||
|
@ -32,7 +32,7 @@ class BaseTranslator(BaseComponent):
|
||||
queries: Optional[List[str]] = None,
|
||||
documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
) -> Union[str, List[str], List[Document], List[Answer], List[List[Document]], List[List[Answer]]]:
|
||||
) -> List[Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]]:
|
||||
pass
|
||||
|
||||
def run( # type: ignore
|
||||
|
@ -161,7 +161,7 @@ class TransformersTranslator(BaseTranslator):
|
||||
queries: Optional[List[str]] = None,
|
||||
documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
) -> Union[str, List[str], List[Document], List[Answer], List[List[Document]], List[List[Answer]]]:
|
||||
) -> List[Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]]:
|
||||
"""
|
||||
Run the actual translation. You can supply a single query, a list of queries or a list (of lists) of documents.
|
||||
|
||||
@ -181,7 +181,7 @@ class TransformersTranslator(BaseTranslator):
|
||||
if queries:
|
||||
translated = []
|
||||
for query in tqdm(queries, disable=not self.progress_bar, desc="Translating"):
|
||||
cur_translation = self.run(query=query)
|
||||
cur_translation = self.translate(query=query)
|
||||
translated.append(cur_translation)
|
||||
|
||||
# Translate docs / answers
|
||||
|
@ -523,6 +523,11 @@ class TranslationWrapperPipeline(BaseStandardPipeline):
|
||||
|
||||
self.pipeline = Pipeline()
|
||||
self.pipeline.add_node(component=input_translator, name="InputTranslator", inputs=["Query"])
|
||||
# Make use of run_batch instead of run for output_translator if pipeline is a QuestionAnswerGenerationPipeline,
|
||||
# as the reader's run method is overwritten by its run_batch method, which is incompatible with the translator's
|
||||
# run method.
|
||||
if isinstance(pipeline, QuestionAnswerGenerationPipeline):
|
||||
setattr(output_translator, "run", output_translator.run_batch)
|
||||
|
||||
graph = pipeline.pipeline.graph
|
||||
previous_node_name = ["InputTranslator"]
|
||||
|
Loading…
x
Reference in New Issue
Block a user