Add run_batch for standard pipelines (#2595)

* Add run_batch for standard pipelines

* Update Documentation & Code Style

* Fix mypy

* Remove code duplication

* Fix linter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
bogdankostic 2022-05-27 10:42:48 +02:00 committed by GitHub
parent b2a2c10fae
commit 0395533a78
3 changed files with 203 additions and 3 deletions


@@ -1285,6 +1285,27 @@ You can select between:
The default value is 'any'.
In Question Answering, to enforce that the retrieved document is considered correct whenever the answer is correct, set `document_scope` to 'answer' or 'document_id_or_answer'.
<a id="standard_pipelines.BaseStandardPipeline.run_batch"></a>
#### BaseStandardPipeline.run\_batch
```python
def run_batch(queries: List[str], params: Optional[dict] = None, debug: Optional[bool] = None)
```
Run a batch of queries through the pipeline.
**Arguments**:
- `queries`: List of query strings.
- `params`: Parameters for the individual nodes of the pipeline. For instance,
`params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}`
- `debug`: Whether the pipeline should instruct nodes to collect debug information
about their execution. By default, this includes the input parameters
they received and the output they generated.
All debug information can then be found in the dict returned
by this method under the key "_debug".
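For illustration, a minimal usage sketch of `run_batch` on an `ExtractiveQAPipeline` (assuming `reader` and `retriever` are already initialized nodes):
```python
from haystack.pipelines import ExtractiveQAPipeline

# `reader` and `retriever` are assumed to be existing, initialized nodes.
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever)
result = pipeline.run_batch(
    queries=["Who wrote Faust?", "Where is the Eiffel Tower?"],
    params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}},
    debug=True,
)
answers_per_query = result["answers"]  # one list of answers per query
debug_info = result.get("_debug")      # populated because debug=True
```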
<a id="standard_pipelines.ExtractiveQAPipeline"></a>
## ExtractiveQAPipeline
@@ -1454,6 +1475,27 @@ they received and the output they generated.
All debug information can then be found in the dict returned
by this method under the key "_debug"
<a id="standard_pipelines.SearchSummarizationPipeline.run_batch"></a>
#### SearchSummarizationPipeline.run\_batch
```python
def run_batch(queries: List[str], params: Optional[dict] = None, debug: Optional[bool] = None)
```
Run a batch of queries through the pipeline.
**Arguments**:
- `queries`: List of query strings.
- `params`: Parameters for the individual nodes of the pipeline. For instance,
`params={"Retriever": {"top_k": 10}, "Summarizer": {"generate_single_summary": True}}`
- `debug`: Whether the pipeline should instruct nodes to collect debug information
about their execution. By default, this includes the input parameters
they received and the output they generated.
All debug information can then be found in the dict returned
by this method under the key "_debug".
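For illustration, a minimal usage sketch (assuming `summarizer` and `retriever` are already initialized nodes):
```python
from haystack.pipelines import SearchSummarizationPipeline

# `summarizer` and `retriever` are assumed to be existing, initialized nodes.
pipeline = SearchSummarizationPipeline(
    summarizer=summarizer,
    retriever=retriever,
    return_in_answer_format=True,  # return answer-shaped dicts instead of documents
)
result = pipeline.run_batch(
    queries=["climate change", "renewable energy sources"],
    params={"Retriever": {"top_k": 10}},
)
summaries_per_query = result["answers"]  # one list of answer dicts per query
```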
<a id="standard_pipelines.FAQPipeline"></a>
## FAQPipeline
@@ -1592,3 +1634,16 @@ def run(document_ids: List[str], top_k: int = 5)
- `document_ids`: IDs of the documents for which to find similar documents.
- `top_k`: How many similar documents to return per input document.
<a id="standard_pipelines.MostSimilarDocumentsPipeline.run_batch"></a>
#### MostSimilarDocumentsPipeline.run\_batch
```python
def run_batch(document_ids: List[str], top_k: int = 5)
```
**Arguments**:
- `document_ids`: IDs of the documents for which to find similar documents.
- `top_k`: How many similar documents to return per input document.
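For illustration, a minimal usage sketch (assuming a `document_store` that already contains documents with embeddings):
```python
from haystack.pipelines import MostSimilarDocumentsPipeline

# `document_store` is assumed to hold embedded documents with ids "a" and "b".
pipeline = MostSimilarDocumentsPipeline(document_store=document_store)
# Returns one list of similar Documents per input document id.
similar_docs_per_id = pipeline.run_batch(document_ids=["a", "b"], top_k=5)
```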


@@ -2,15 +2,14 @@ import logging
from abc import ABC
from copy import deepcopy
from pathlib import Path
from typing import List, Optional, Dict, Any
from functools import wraps
from typing import List, Optional, Dict, Any, Union
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal # type: ignore
from functools import wraps
from haystack.schema import Document, EvaluationResult, MultiLabel
from haystack.nodes.answer_generator.base import BaseGenerator
from haystack.nodes.other.docs2answers import Docs2Answers
@@ -277,6 +276,22 @@ class BaseStandardPipeline(ABC):
answer_scope=answer_scope,
)
def run_batch(self, queries: List[str], params: Optional[dict] = None, debug: Optional[bool] = None):
"""
Run a batch of queries through the pipeline.
:param queries: List of query strings.
:param params: Parameters for the individual nodes of the pipeline. For instance,
`params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}`
:param debug: Whether the pipeline should instruct nodes to collect debug information
about their execution. By default, this includes the input parameters
they received and the output they generated.
All debug information can then be found in the dict returned
by this method under the key "_debug".
"""
output = self.pipeline.run_batch(queries=queries, params=params, debug=debug)
return output
class ExtractiveQAPipeline(BaseStandardPipeline):
"""
@@ -415,6 +430,45 @@ class SearchSummarizationPipeline(BaseStandardPipeline):
results = output
return results
def run_batch(self, queries: List[str], params: Optional[dict] = None, debug: Optional[bool] = None):
"""
Run a batch of queries through the pipeline.
:param queries: List of query strings.
:param params: Parameters for the individual nodes of the pipeline. For instance,
`params={"Retriever": {"top_k": 10}, "Summarizer": {"generate_single_summary": True}}`
:param debug: Whether the pipeline should instruct nodes to collect debug information
about their execution. By default, this includes the input parameters
they received and the output they generated.
All debug information can then be found in the dict returned
by this method under the key "_debug".
"""
output = self.pipeline.run_batch(queries=queries, params=params, debug=debug)
# Convert to answer format to allow "drop-in replacement" for other QA pipelines
if self.return_in_answer_format:
results: Dict = {"queries": queries, "answers": []}
docs = deepcopy(output["documents"])
for query, cur_docs in zip(queries, docs):
cur_answers = []
for doc in cur_docs:
cur_answer = {
"query": query,
"answer": doc.content,
"document_id": doc.id,
"context": doc.meta.pop("context"),
"score": None,
"offset_start": None,
"offset_end": None,
"meta": doc.meta,
}
cur_answers.append(cur_answer)
results["answers"].append(cur_answers)
else:
results = output
return results
class FAQPipeline(BaseStandardPipeline):
"""
@@ -484,6 +538,10 @@ class TranslationWrapperPipeline(BaseStandardPipeline):
output = self.pipeline.run(**kwargs)
return output
def run_batch(self, **kwargs):
output = self.pipeline.run_batch(**kwargs)
return output
class QuestionGenerationPipeline(BaseStandardPipeline):
"""
@@ -499,6 +557,15 @@ class QuestionGenerationPipeline(BaseStandardPipeline):
output = self.pipeline.run(documents=documents, params=params, debug=debug)
return output
def run_batch( # type: ignore
self,
documents: Union[List[Document], List[List[Document]]],
params: Optional[dict] = None,
debug: Optional[bool] = None,
):
output = self.pipeline.run_batch(documents=documents, params=params, debug=debug)
return output
class RetrieverQuestionGenerationPipeline(BaseStandardPipeline):
"""
@@ -580,3 +647,10 @@ class MostSimilarDocumentsPipeline(BaseStandardPipeline):
self.document_store.return_embedding = False # type: ignore
return similar_documents
def run_batch(self, document_ids: List[str], top_k: int = 5): # type: ignore
"""
:param document_ids: IDs of the documents for which to find similar documents.
:param top_k: How many similar documents to return per input document.
"""
return self.run(document_ids=document_ids, top_k=top_k)


@@ -52,6 +52,28 @@ def test_faq_pipeline(retriever, document_store):
assert len(output["answers"]) == 1
@pytest.mark.parametrize("retriever,document_store", [("embedding", "memory")], indirect=True)
def test_faq_pipeline_batch(retriever, document_store):
documents = [
{"content": "How to test module-1?", "meta": {"source": "wiki1", "answer": "Using tests for module-1"}},
{"content": "How to test module-2?", "meta": {"source": "wiki2", "answer": "Using tests for module-2"}},
{"content": "How to test module-3?", "meta": {"source": "wiki3", "answer": "Using tests for module-3"}},
{"content": "How to test module-4?", "meta": {"source": "wiki4", "answer": "Using tests for module-4"}},
{"content": "How to test module-5?", "meta": {"source": "wiki5", "answer": "Using tests for module-5"}},
]
document_store.write_documents(documents)
document_store.update_embeddings(retriever)
pipeline = FAQPipeline(retriever=retriever)
output = pipeline.run_batch(queries=["How to test this?", "How to test this?"], params={"Retriever": {"top_k": 3}})
assert len(output["answers"]) == 2 # 2 queries
assert len(output["answers"][0]) == 3 # 3 answers per query
assert output["queries"][0].startswith("How to")
assert output["answers"][0][0].answer.startswith("Using tests")
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
@pytest.mark.parametrize(
"document_store", ["elasticsearch", "faiss", "memory", "milvus1", "milvus", "weaviate", "pinecone"], indirect=True
@@ -77,6 +99,26 @@ def test_document_search_pipeline(retriever, document_store):
assert len(output["documents"]) == 1
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
def test_document_search_pipeline_batch(retriever, document_store):
documents = [
{"content": "Sample text for document-1", "meta": {"source": "wiki1"}},
{"content": "Sample text for document-2", "meta": {"source": "wiki2"}},
{"content": "Sample text for document-3", "meta": {"source": "wiki3"}},
{"content": "Sample text for document-4", "meta": {"source": "wiki4"}},
{"content": "Sample text for document-5", "meta": {"source": "wiki5"}},
]
document_store.write_documents(documents)
document_store.update_embeddings(retriever)
pipeline = DocumentSearchPipeline(retriever=retriever)
output = pipeline.run_batch(queries=["How to test this?", "How to test this?"], params={"top_k": 4})
assert len(output["documents"]) == 2 # 2 queries
assert len(output["documents"][0]) == 4 # 4 docs per query
@pytest.mark.integration
@pytest.mark.parametrize("retriever_with_docs", ["elasticsearch", "dpr", "embedding"], indirect=True)
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
@@ -158,6 +200,35 @@ def test_most_similar_documents_pipeline(retriever, document_store):
assert isinstance(document.content, str)
@pytest.mark.parametrize("retriever,document_store", [("embedding", "memory")], indirect=True)
def test_most_similar_documents_pipeline_batch(retriever, document_store):
documents = [
{"id": "a", "content": "Sample text for document-1", "meta": {"source": "wiki1"}},
{"id": "b", "content": "Sample text for document-2", "meta": {"source": "wiki2"}},
{"content": "Sample text for document-3", "meta": {"source": "wiki3"}},
{"content": "Sample text for document-4", "meta": {"source": "wiki4"}},
{"content": "Sample text for document-5", "meta": {"source": "wiki5"}},
]
document_store.write_documents(documents)
document_store.update_embeddings(retriever)
docs_id: list = ["a", "b"]
pipeline = MostSimilarDocumentsPipeline(document_store=document_store)
list_of_documents = pipeline.run_batch(document_ids=docs_id)
assert len(list_of_documents[0]) > 1
assert isinstance(list_of_documents, list)
assert len(list_of_documents) == len(docs_id)
for another_list in list_of_documents:
assert isinstance(another_list, list)
for document in another_list:
assert isinstance(document, Document)
assert isinstance(document.id, str)
assert isinstance(document.content, str)
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True)
def test_join_merge_no_weights(document_store_dot_product_with_docs):