From 0395533a786cc63bf2f5180ee7d3dc3eefebdd59 Mon Sep 17 00:00:00 2001 From: bogdankostic Date: Fri, 27 May 2022 10:42:48 +0200 Subject: [PATCH] Add `run_batch` for standard pipelines (#2595) * Add run_batch for standard pipelines * Update Documentation & Code Style * Fix mypy * Remove code duplication * Fix linter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/_src/api/api/pipelines.md | 55 ++++++++++++++++ haystack/pipelines/standard_pipelines.py | 80 ++++++++++++++++++++++- test/pipelines/test_standard_pipelines.py | 71 ++++++++++++++++++++ 3 files changed, 203 insertions(+), 3 deletions(-) diff --git a/docs/_src/api/api/pipelines.md b/docs/_src/api/api/pipelines.md index 3243e195f..da0b500bc 100644 --- a/docs/_src/api/api/pipelines.md +++ b/docs/_src/api/api/pipelines.md @@ -1285,6 +1285,27 @@ You can select between: The default value is 'any'. In Question Answering, to enforce that the retrieved document is considered correct whenever the answer is correct, set `document_scope` to 'answer' or 'document_id_or_answer'. + + +#### BaseStandardPipeline.run\_batch + +```python +def run_batch(queries: List[str], params: Optional[dict] = None, debug: Optional[bool] = None) +``` + +Run a batch of queries through the pipeline. + +**Arguments**: + +- `queries`: List of query strings. +- `params`: Parameters for the individual nodes of the pipeline. For instance, +`params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}` +- `debug`: Whether the pipeline should instruct nodes to collect debug information +about their execution. By default these include the input parameters +they received and the output they generated. +All debug information can then be found in the dict returned +by this method under the key "_debug" + ## ExtractiveQAPipeline @@ -1454,6 +1475,27 @@ they received and the output they generated. All debug information can then be found in the dict returned by this method under the key "_debug" + + +#### SearchSummarizationPipeline.run\_batch + +```python +def run_batch(queries: List[str], params: Optional[dict] = None, debug: Optional[bool] = None) +``` + +Run a batch of queries through the pipeline. + +**Arguments**: + +- `queries`: List of query strings. +- `params`: Parameters for the individual nodes of the pipeline. For instance, +`params={"Retriever": {"top_k": 10}, "Summarizer": {"generate_single_summary": True}}` +- `debug`: Whether the pipeline should instruct nodes to collect debug information +about their execution. By default these include the input parameters +they received and the output they generated. 
+All debug information can then be found in the dict returned +by this method under the key "_debug" + ## FAQPipeline @@ -1592,3 +1634,16 @@ def run(document_ids: List[str], top_k: int = 5) - `document_ids`: document ids - `top_k`: How many documents id to return against single document + + +#### MostSimilarDocumentsPipeline.run\_batch + +```python +def run_batch(document_ids: List[str], top_k: int = 5) +``` + +**Arguments**: + +- `document_ids`: document ids +- `top_k`: How many documents id to return against single document + diff --git a/haystack/pipelines/standard_pipelines.py b/haystack/pipelines/standard_pipelines.py index ced2955af..585f9a57a 100644 --- a/haystack/pipelines/standard_pipelines.py +++ b/haystack/pipelines/standard_pipelines.py @@ -2,15 +2,14 @@ import logging from abc import ABC from copy import deepcopy from pathlib import Path -from typing import List, Optional, Dict, Any +from functools import wraps +from typing import List, Optional, Dict, Any, Union try: from typing import Literal except ImportError: from typing_extensions import Literal # type: ignore -from functools import wraps - from haystack.schema import Document, EvaluationResult, MultiLabel from haystack.nodes.answer_generator.base import BaseGenerator from haystack.nodes.other.docs2answers import Docs2Answers @@ -277,6 +276,22 @@ class BaseStandardPipeline(ABC): answer_scope=answer_scope, ) + def run_batch(self, queries: List[str], params: Optional[dict] = None, debug: Optional[bool] = None): + """ + Run a batch of queries through the pipeline. + + :param queries: List of query strings. + :param params: Parameters for the individual nodes of the pipeline. For instance, + `params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}}` + :param debug: Whether the pipeline should instruct nodes to collect debug information + about their execution. By default these include the input parameters + they received and the output they generated. + All debug information can then be found in the dict returned + by this method under the key "_debug" + """ + output = self.pipeline.run_batch(queries=queries, params=params, debug=debug) + return output + class ExtractiveQAPipeline(BaseStandardPipeline): """ @@ -415,6 +430,45 @@ class SearchSummarizationPipeline(BaseStandardPipeline): results = output return results + def run_batch(self, queries: List[str], params: Optional[dict] = None, debug: Optional[bool] = None): + """ + Run a batch of queries through the pipeline. + + :param queries: List of query strings. + :param params: Parameters for the individual nodes of the pipeline. For instance, + `params={"Retriever": {"top_k": 10}, "Summarizer": {"generate_single_summary": True}}` + :param debug: Whether the pipeline should instruct nodes to collect debug information + about their execution. By default these include the input parameters + they received and the output they generated. 
+ All debug information can then be found in the dict returned + by this method under the key "_debug" + """ + output = self.pipeline.run_batch(queries=queries, params=params, debug=debug) + + # Convert to answer format to allow "drop-in replacement" for other QA pipelines + if self.return_in_answer_format: + results: Dict = {"queries": queries, "answers": []} + docs = deepcopy(output["documents"]) + for query, cur_docs in zip(queries, docs): + cur_answers = [] + for doc in cur_docs: + cur_answer = { + "query": query, + "answer": doc.content, + "document_id": doc.id, + "context": doc.meta.pop("context"), + "score": None, + "offset_start": None, + "offset_end": None, + "meta": doc.meta, + } + cur_answers.append(cur_answer) + + results["answers"].append(cur_answers) + else: + results = output + return results + class FAQPipeline(BaseStandardPipeline): """ @@ -484,6 +538,10 @@ class TranslationWrapperPipeline(BaseStandardPipeline): output = self.pipeline.run(**kwargs) return output + def run_batch(self, **kwargs): + output = self.pipeline.run_batch(**kwargs) + return output + class QuestionGenerationPipeline(BaseStandardPipeline): """ @@ -499,6 +557,15 @@ class QuestionGenerationPipeline(BaseStandardPipeline): output = self.pipeline.run(documents=documents, params=params, debug=debug) return output + def run_batch( # type: ignore + self, + documents: Union[List[Document], List[List[Document]]], + params: Optional[dict] = None, + debug: Optional[bool] = None, + ): + output = self.pipeline.run_batch(documents=documents, params=params, debug=debug) + return output + class RetrieverQuestionGenerationPipeline(BaseStandardPipeline): """ @@ -580,3 +647,10 @@ class MostSimilarDocumentsPipeline(BaseStandardPipeline): self.document_store.return_embedding = False # type: ignore return similar_documents + + def run_batch(self, document_ids: List[str], top_k: int = 5): # type: ignore + """ + :param document_ids: document ids + :param top_k: How many documents id to return against single document + """ + return self.run(document_ids=document_ids, top_k=top_k) diff --git a/test/pipelines/test_standard_pipelines.py b/test/pipelines/test_standard_pipelines.py index e51e733e9..57e59665a 100644 --- a/test/pipelines/test_standard_pipelines.py +++ b/test/pipelines/test_standard_pipelines.py @@ -52,6 +52,28 @@ def test_faq_pipeline(retriever, document_store): assert len(output["answers"]) == 1 +@pytest.mark.parametrize("retriever,document_store", [("embedding", "memory")], indirect=True) +def test_faq_pipeline_batch(retriever, document_store): + documents = [ + {"content": "How to test module-1?", "meta": {"source": "wiki1", "answer": "Using tests for module-1"}}, + {"content": "How to test module-2?", "meta": {"source": "wiki2", "answer": "Using tests for module-2"}}, + {"content": "How to test module-3?", "meta": {"source": "wiki3", "answer": "Using tests for module-3"}}, + {"content": "How to test module-4?", "meta": {"source": "wiki4", "answer": "Using tests for module-4"}}, + {"content": "How to test module-5?", "meta": {"source": "wiki5", "answer": "Using tests for module-5"}}, + ] + + document_store.write_documents(documents) + document_store.update_embeddings(retriever) + + pipeline = FAQPipeline(retriever=retriever) + + output = pipeline.run_batch(queries=["How to test this?", "How to test this?"], params={"Retriever": {"top_k": 3}}) + assert len(output["answers"]) == 2 # 2 queries + assert len(output["answers"][0]) == 3 # 3 answers per query + assert output["queries"][0].startswith("How to") + assert 
output["answers"][0][0].answer.startswith("Using tests") + + @pytest.mark.parametrize("retriever", ["embedding"], indirect=True) @pytest.mark.parametrize( "document_store", ["elasticsearch", "faiss", "memory", "milvus1", "milvus", "weaviate", "pinecone"], indirect=True @@ -77,6 +99,26 @@ def test_document_search_pipeline(retriever, document_store): assert len(output["documents"]) == 1 +@pytest.mark.parametrize("retriever", ["embedding"], indirect=True) +@pytest.mark.parametrize("document_store", ["memory"], indirect=True) +def test_document_search_pipeline_batch(retriever, document_store): + documents = [ + {"content": "Sample text for document-1", "meta": {"source": "wiki1"}}, + {"content": "Sample text for document-2", "meta": {"source": "wiki2"}}, + {"content": "Sample text for document-3", "meta": {"source": "wiki3"}}, + {"content": "Sample text for document-4", "meta": {"source": "wiki4"}}, + {"content": "Sample text for document-5", "meta": {"source": "wiki5"}}, + ] + + document_store.write_documents(documents) + document_store.update_embeddings(retriever) + + pipeline = DocumentSearchPipeline(retriever=retriever) + output = pipeline.run_batch(queries=["How to test this?", "How to test this?"], params={"top_k": 4}) + assert len(output["documents"]) == 2 # 2 queries + assert len(output["documents"][0]) == 4 # 4 docs per query + + @pytest.mark.integration @pytest.mark.parametrize("retriever_with_docs", ["elasticsearch", "dpr", "embedding"], indirect=True) @pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True) @@ -158,6 +200,35 @@ def test_most_similar_documents_pipeline(retriever, document_store): assert isinstance(document.content, str) +@pytest.mark.parametrize("retriever,document_store", [("embedding", "memory")], indirect=True) +def test_most_similar_documents_pipeline_batch(retriever, document_store): + documents = [ + {"id": "a", "content": "Sample text for document-1", "meta": {"source": "wiki1"}}, + {"id": "b", "content": "Sample text for document-2", "meta": {"source": "wiki2"}}, + {"content": "Sample text for document-3", "meta": {"source": "wiki3"}}, + {"content": "Sample text for document-4", "meta": {"source": "wiki4"}}, + {"content": "Sample text for document-5", "meta": {"source": "wiki5"}}, + ] + + document_store.write_documents(documents) + document_store.update_embeddings(retriever) + + docs_id: list = ["a", "b"] + pipeline = MostSimilarDocumentsPipeline(document_store=document_store) + list_of_documents = pipeline.run_batch(document_ids=docs_id) + + assert len(list_of_documents[0]) > 1 + assert isinstance(list_of_documents, list) + assert len(list_of_documents) == len(docs_id) + + for another_list in list_of_documents: + assert isinstance(another_list, list) + for document in another_list: + assert isinstance(document, Document) + assert isinstance(document.id, str) + assert isinstance(document.content, str) + + @pytest.mark.elasticsearch @pytest.mark.parametrize("document_store_dot_product_with_docs", ["elasticsearch"], indirect=True) def test_join_merge_no_weights(document_store_dot_product_with_docs):