mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-01 20:33:48 +00:00
bug: Make TranslationWrapperPipeline
work with QuestionAnswerGenerationPipeline
(#3034)
* Overwrite output_translator's run method with run_batch * Fix mypy * Revert change * Overwrite run method only with QuestionAnswerGenerationPipeline
This commit is contained in:
parent
1b422ab657
commit
3a849d6c07
@ -117,7 +117,7 @@ Run the actual translation. You can supply a query or a list of documents. Whate
|
|||||||
#### TransformersTranslator.translate\_batch
|
#### TransformersTranslator.translate\_batch
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def translate_batch(queries: Optional[List[str]] = None, documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None, batch_size: Optional[int] = None) -> Union[str, List[str], List[Document], List[Answer], List[List[Document]], List[List[Answer]]]
|
def translate_batch(queries: Optional[List[str]] = None, documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None, batch_size: Optional[int] = None) -> List[Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]]
|
||||||
```
|
```
|
||||||
|
|
||||||
Run the actual translation. You can supply a single query, a list of queries or a list (of lists) of documents.
|
Run the actual translation. You can supply a single query, a list of queries or a list (of lists) of documents.
|
||||||
|
@ -32,7 +32,7 @@ class BaseTranslator(BaseComponent):
|
|||||||
queries: Optional[List[str]] = None,
|
queries: Optional[List[str]] = None,
|
||||||
documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None,
|
documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None,
|
||||||
batch_size: Optional[int] = None,
|
batch_size: Optional[int] = None,
|
||||||
) -> Union[str, List[str], List[Document], List[Answer], List[List[Document]], List[List[Answer]]]:
|
) -> List[Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def run( # type: ignore
|
def run( # type: ignore
|
||||||
|
@ -161,7 +161,7 @@ class TransformersTranslator(BaseTranslator):
|
|||||||
queries: Optional[List[str]] = None,
|
queries: Optional[List[str]] = None,
|
||||||
documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None,
|
documents: Optional[Union[List[Document], List[Answer], List[List[Document]], List[List[Answer]]]] = None,
|
||||||
batch_size: Optional[int] = None,
|
batch_size: Optional[int] = None,
|
||||||
) -> Union[str, List[str], List[Document], List[Answer], List[List[Document]], List[List[Answer]]]:
|
) -> List[Union[str, List[Document], List[Answer], List[str], List[Dict[str, Any]]]]:
|
||||||
"""
|
"""
|
||||||
Run the actual translation. You can supply a single query, a list of queries or a list (of lists) of documents.
|
Run the actual translation. You can supply a single query, a list of queries or a list (of lists) of documents.
|
||||||
|
|
||||||
@ -181,7 +181,7 @@ class TransformersTranslator(BaseTranslator):
|
|||||||
if queries:
|
if queries:
|
||||||
translated = []
|
translated = []
|
||||||
for query in tqdm(queries, disable=not self.progress_bar, desc="Translating"):
|
for query in tqdm(queries, disable=not self.progress_bar, desc="Translating"):
|
||||||
cur_translation = self.run(query=query)
|
cur_translation = self.translate(query=query)
|
||||||
translated.append(cur_translation)
|
translated.append(cur_translation)
|
||||||
|
|
||||||
# Translate docs / answers
|
# Translate docs / answers
|
||||||
|
@ -523,6 +523,11 @@ class TranslationWrapperPipeline(BaseStandardPipeline):
|
|||||||
|
|
||||||
self.pipeline = Pipeline()
|
self.pipeline = Pipeline()
|
||||||
self.pipeline.add_node(component=input_translator, name="InputTranslator", inputs=["Query"])
|
self.pipeline.add_node(component=input_translator, name="InputTranslator", inputs=["Query"])
|
||||||
|
# Make use of run_batch instead of run for output_translator if pipeline is a QuestionAnswerGenerationPipeline,
|
||||||
|
# as the reader's run method is overwritten by its run_batch method, which is incompatible with the translator's
|
||||||
|
# run method.
|
||||||
|
if isinstance(pipeline, QuestionAnswerGenerationPipeline):
|
||||||
|
setattr(output_translator, "run", output_translator.run_batch)
|
||||||
|
|
||||||
graph = pipeline.pipeline.graph
|
graph = pipeline.pipeline.graph
|
||||||
previous_node_name = ["InputTranslator"]
|
previous_node_name = ["InputTranslator"]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user