mirror of https://github.com/deepset-ai/haystack.git (synced 2026-01-08 13:06:29 +00:00)
Add run_batch for standard pipelines (#2595)
* Add run_batch for standard pipelines
* Update Documentation & Code Style
* Fix mypy
* Remove code duplication
* Fix linter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent: b2a2c10fae
commit: 0395533a78
@@ -1285,6 +1285,27 @@ You can select between:

The default value is 'any'.

In Question Answering, to enforce that the retrieved document is considered correct whenever the answer is correct, set `document_scope` to 'answer' or 'document_id_or_answer'.
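A minimal sketch of where `document_scope` is applied, assuming an `EvaluationResult` named `eval_result` obtained from a prior `pipeline.eval(...)` call (the metric key shown is illustrative and depends on your pipeline's nodes):

```python
# `eval_result` is assumed to come from a prior `pipeline.eval(labels=..., params=...)` call.
metrics = eval_result.calculate_metrics(document_scope="document_id_or_answer")
print(metrics["Retriever"]["recall_single_hit"])  # metric keys depend on the pipeline's nodes
```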
<a id="standard_pipelines.BaseStandardPipeline.run_batch"></a>
|
||||
|
||||
#### BaseStandardPipeline.run\_batch
|
||||
|
||||
```python
|
||||
def run_batch(queries: List[str], params: Optional[dict] = None, debug: Optional[bool] = None)
|
||||
```
|
||||
|
||||
Run a batch of queries through the pipeline.
|
||||
|
||||
**Arguments**:
|
||||
|
||||
- `queries`: List of query strings.
|
||||
- `params`: Parameters for the individual nodes of the pipeline. For instance,
|
||||
`params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}`
|
||||
- `debug`: Whether the pipeline should instruct nodes to collect debug information
|
||||
about their execution. By default these include the input parameters
|
||||
they received and the output they generated.
|
||||
All debug information can then be found in the dict returned
|
||||
by this method under the key "_debug"
|
||||
|
||||
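For example, a minimal batch-querying sketch, assuming an `ExtractiveQAPipeline` built from already-initialized `retriever` and `reader` nodes (the queries are illustrative):

```python
from haystack.pipelines import ExtractiveQAPipeline

# `retriever` and `reader` are assumed to be already-initialized Haystack nodes.
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever)

result = pipeline.run_batch(
    queries=["Who wrote Faust?", "Where is the Eiffel Tower?"],
    params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}},
)

# run_batch returns one list of answers per input query.
for query, answers in zip(result["queries"], result["answers"]):
    print(query, "->", answers[0].answer if answers else None)
```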
<a id="standard_pipelines.ExtractiveQAPipeline"></a>
|
||||
|
||||
## ExtractiveQAPipeline
|
||||
@@ -1454,6 +1475,27 @@ they received and the output they generated.

All debug information can then be found in the dict returned
by this method under the key "_debug"
<a id="standard_pipelines.SearchSummarizationPipeline.run_batch"></a>
|
||||
|
||||
#### SearchSummarizationPipeline.run\_batch
|
||||
|
||||
```python
|
||||
def run_batch(queries: List[str], params: Optional[dict] = None, debug: Optional[bool] = None)
|
||||
```
|
||||
|
||||
Run a batch of queries through the pipeline.
|
||||
|
||||
**Arguments**:
|
||||
|
||||
- `queries`: List of query strings.
|
||||
- `params`: Parameters for the individual nodes of the pipeline. For instance,
|
||||
`params={"Retriever": {"top_k": 10}, "Summarizer": {"generate_single_summary": True}}`
|
||||
- `debug`: Whether the pipeline should instruct nodes to collect debug information
|
||||
about their execution. By default these include the input parameters
|
||||
they received and the output they generated.
|
||||
All debug information can then be found in the dict returned
|
||||
by this method under the key "_debug"
|
||||
|
||||
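A minimal usage sketch, assuming already-initialized `retriever` and `summarizer` nodes and a pipeline constructed with `return_in_answer_format=True` (queries are illustrative):

```python
from haystack.pipelines import SearchSummarizationPipeline

# `retriever` and `summarizer` are assumed to be already-initialized Haystack nodes.
pipeline = SearchSummarizationPipeline(
    retriever=retriever,
    summarizer=summarizer,
    return_in_answer_format=True,  # return answer dicts instead of raw documents
)

output = pipeline.run_batch(
    queries=["effects of climate change", "history of the printing press"],
    params={"Retriever": {"top_k": 10}},
)

# In answer format, each query maps to a list of answer dicts whose
# "answer" field holds the summary text.
for answers in output["answers"]:
    for answer in answers:
        print(answer["answer"])
```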
<a id="standard_pipelines.FAQPipeline"></a>
|
||||
|
||||
## FAQPipeline
|
||||
@@ -1592,3 +1634,16 @@ def run(document_ids: List[str], top_k: int = 5)

- `document_ids`: IDs of the documents for which to find the most similar documents.
- `top_k`: How many similar documents to return for each input document.
<a id="standard_pipelines.MostSimilarDocumentsPipeline.run_batch"></a>
|
||||
|
||||
#### MostSimilarDocumentsPipeline.run\_batch
|
||||
|
||||
```python
|
||||
def run_batch(document_ids: List[str], top_k: int = 5)
|
||||
```
|
||||
|
||||
**Arguments**:
|
||||
|
||||
- `document_ids`: document ids
|
||||
- `top_k`: How many documents id to return against single document
|
||||
|
||||
|
||||
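A minimal sketch, assuming a `document_store` that already holds embedded documents under the IDs "a" and "b":

```python
from haystack.pipelines import MostSimilarDocumentsPipeline

# `document_store` is assumed to already hold documents (with embeddings)
# under the IDs "a" and "b".
pipeline = MostSimilarDocumentsPipeline(document_store=document_store)

# Returns one list of similar documents per input document ID.
similar = pipeline.run_batch(document_ids=["a", "b"], top_k=5)
for doc_id, matches in zip(["a", "b"], similar):
    print(doc_id, [doc.id for doc in matches])
```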
@@ -2,15 +2,14 @@ import logging

from abc import ABC
from copy import deepcopy
from pathlib import Path
from typing import List, Optional, Dict, Any
from functools import wraps
from typing import List, Optional, Dict, Any, Union

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal  # type: ignore

from functools import wraps

from haystack.schema import Document, EvaluationResult, MultiLabel
from haystack.nodes.answer_generator.base import BaseGenerator
from haystack.nodes.other.docs2answers import Docs2Answers
@@ -277,6 +276,22 @@ class BaseStandardPipeline(ABC):

            answer_scope=answer_scope,
        )

    def run_batch(self, queries: List[str], params: Optional[dict] = None, debug: Optional[bool] = None):
        """
        Run a batch of queries through the pipeline.

        :param queries: List of query strings.
        :param params: Parameters for the individual nodes of the pipeline. For instance,
                       `params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}`
        :param debug: Whether the pipeline should instruct nodes to collect debug information
                      about their execution. By default these include the input parameters
                      they received and the output they generated.
                      All debug information can then be found in the dict returned
                      by this method under the key "_debug"
        """
        output = self.pipeline.run_batch(queries=queries, params=params, debug=debug)
        return output


class ExtractiveQAPipeline(BaseStandardPipeline):
    """
@@ -415,6 +430,45 @@ class SearchSummarizationPipeline(BaseStandardPipeline):

            results = output
        return results

    def run_batch(self, queries: List[str], params: Optional[dict] = None, debug: Optional[bool] = None):
        """
        Run a batch of queries through the pipeline.

        :param queries: List of query strings.
        :param params: Parameters for the individual nodes of the pipeline. For instance,
                       `params={"Retriever": {"top_k": 10}, "Summarizer": {"generate_single_summary": True}}`
        :param debug: Whether the pipeline should instruct nodes to collect debug information
                      about their execution. By default these include the input parameters
                      they received and the output they generated.
                      All debug information can then be found in the dict returned
                      by this method under the key "_debug"
        """
        output = self.pipeline.run_batch(queries=queries, params=params, debug=debug)

        # Convert to answer format to allow "drop-in replacement" for other QA pipelines
        if self.return_in_answer_format:
            results: Dict = {"queries": queries, "answers": []}
            docs = deepcopy(output["documents"])
            for query, cur_docs in zip(queries, docs):
                cur_answers = []
                for doc in cur_docs:
                    cur_answer = {
                        "query": query,
                        "answer": doc.content,
                        "document_id": doc.id,
                        "context": doc.meta.pop("context"),
                        "score": None,
                        "offset_start": None,
                        "offset_end": None,
                        "meta": doc.meta,
                    }
                    cur_answers.append(cur_answer)

                results["answers"].append(cur_answers)
        else:
            results = output
        return results


class FAQPipeline(BaseStandardPipeline):
    """
@@ -484,6 +538,10 @@ class TranslationWrapperPipeline(BaseStandardPipeline):

        output = self.pipeline.run(**kwargs)
        return output

    def run_batch(self, **kwargs):
        output = self.pipeline.run_batch(**kwargs)
        return output


class QuestionGenerationPipeline(BaseStandardPipeline):
    """
@@ -499,6 +557,15 @@ class QuestionGenerationPipeline(BaseStandardPipeline):

        output = self.pipeline.run(documents=documents, params=params, debug=debug)
        return output

    def run_batch(  # type: ignore
        self,
        documents: Union[List[Document], List[List[Document]]],
        params: Optional[dict] = None,
        debug: Optional[bool] = None,
    ):
        output = self.pipeline.run_batch(documents=documents, params=params, debug=debug)
        return output


class RetrieverQuestionGenerationPipeline(BaseStandardPipeline):
    """
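Since this `run_batch` accepts either a flat `List[Document]` or a nested `List[List[Document]]`, a minimal sketch follows, assuming an already-initialized `question_generator` node (the output key is an assumption; it depends on the QuestionGenerator node):

```python
from haystack import Document
from haystack.pipelines import QuestionGenerationPipeline

# `question_generator` is assumed to be an already-initialized QuestionGenerator node.
pipeline = QuestionGenerationPipeline(question_generator)

docs = [
    Document(content="Python was created by Guido van Rossum."),
    Document(content="The Eiffel Tower is located in Paris."),
]

# A flat list of Documents is accepted; each document yields its own questions.
output = pipeline.run_batch(documents=docs)
print(output.get("generated_questions"))  # key assumed; depends on the QuestionGenerator node
```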
@@ -580,3 +647,10 @@ class MostSimilarDocumentsPipeline(BaseStandardPipeline):

        self.document_store.return_embedding = False  # type: ignore
        return similar_documents

    def run_batch(self, document_ids: List[str], top_k: int = 5):  # type: ignore
        """
        :param document_ids: IDs of the documents for which to find the most similar documents.
        :param top_k: How many similar documents to return for each input document.
        """
        return self.run(document_ids=document_ids, top_k=top_k)
@@ -52,6 +52,28 @@ def test_faq_pipeline(retriever, document_store):

    assert len(output["answers"]) == 1


@pytest.mark.parametrize("retriever,document_store", [("embedding", "memory")], indirect=True)
def test_faq_pipeline_batch(retriever, document_store):
    documents = [
        {"content": "How to test module-1?", "meta": {"source": "wiki1", "answer": "Using tests for module-1"}},
        {"content": "How to test module-2?", "meta": {"source": "wiki2", "answer": "Using tests for module-2"}},
        {"content": "How to test module-3?", "meta": {"source": "wiki3", "answer": "Using tests for module-3"}},
        {"content": "How to test module-4?", "meta": {"source": "wiki4", "answer": "Using tests for module-4"}},
        {"content": "How to test module-5?", "meta": {"source": "wiki5", "answer": "Using tests for module-5"}},
    ]

    document_store.write_documents(documents)
    document_store.update_embeddings(retriever)

    pipeline = FAQPipeline(retriever=retriever)

    output = pipeline.run_batch(queries=["How to test this?", "How to test this?"], params={"Retriever": {"top_k": 3}})
    assert len(output["answers"]) == 2  # 2 queries
    assert len(output["answers"][0]) == 3  # 3 answers per query
    assert output["queries"][0].startswith("How to")
    assert output["answers"][0][0].answer.startswith("Using tests")


@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
@pytest.mark.parametrize(
    "document_store", ["elasticsearch", "faiss", "memory", "milvus1", "milvus", "weaviate", "pinecone"], indirect=True
@@ -77,6 +99,26 @@ def test_document_search_pipeline(retriever, document_store):

    assert len(output["documents"]) == 1


@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
def test_document_search_pipeline_batch(retriever, document_store):
    documents = [
        {"content": "Sample text for document-1", "meta": {"source": "wiki1"}},
        {"content": "Sample text for document-2", "meta": {"source": "wiki2"}},
        {"content": "Sample text for document-3", "meta": {"source": "wiki3"}},
        {"content": "Sample text for document-4", "meta": {"source": "wiki4"}},
        {"content": "Sample text for document-5", "meta": {"source": "wiki5"}},
    ]

    document_store.write_documents(documents)
    document_store.update_embeddings(retriever)

    pipeline = DocumentSearchPipeline(retriever=retriever)
    output = pipeline.run_batch(queries=["How to test this?", "How to test this?"], params={"top_k": 4})
    assert len(output["documents"]) == 2  # 2 queries
    assert len(output["documents"][0]) == 4  # 4 docs per query


@pytest.mark.integration
@pytest.mark.parametrize("retriever_with_docs", ["elasticsearch", "dpr", "embedding"], indirect=True)
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
@@ -158,6 +200,35 @@ def test_most_similar_documents_pipeline(retriever, document_store):

            assert isinstance(document.content, str)


@pytest.mark.parametrize("retriever,document_store", [("embedding", "memory")], indirect=True)
def test_most_similar_documents_pipeline_batch(retriever, document_store):
    documents = [
        {"id": "a", "content": "Sample text for document-1", "meta": {"source": "wiki1"}},
        {"id": "b", "content": "Sample text for document-2", "meta": {"source": "wiki2"}},
        {"content": "Sample text for document-3", "meta": {"source": "wiki3"}},
        {"content": "Sample text for document-4", "meta": {"source": "wiki4"}},
        {"content": "Sample text for document-5", "meta": {"source": "wiki5"}},
    ]

    document_store.write_documents(documents)
    document_store.update_embeddings(retriever)

    docs_id: list = ["a", "b"]
    pipeline = MostSimilarDocumentsPipeline(document_store=document_store)
    list_of_documents = pipeline.run_batch(document_ids=docs_id)

    assert len(list_of_documents[0]) > 1
    assert isinstance(list_of_documents, list)
    assert len(list_of_documents) == len(docs_id)

    for another_list in list_of_documents:
        assert isinstance(another_list, list)
        for document in another_list:
            assert isinstance(document, Document)
            assert isinstance(document.id, str)
            assert isinstance(document.content, str)


@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True)
def test_join_merge_no_weights(document_store_dot_product_with_docs):