From bc86f5771549d64d8f5a92f31ca58a2b682d485f Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 27 Jun 2023 17:42:23 +0200 Subject: [PATCH] feat: BM25 retrieval for `MemoryDocumentStore` (#5151) --- .../preview/components/retrievers/__init__.py | 0 .../preview/components/retrievers/memory.py | 84 ++++++++ .../document_stores/memory/document_store.py | 101 ++++++++- .../preview/components/retrievers/__init__.py | 0 .../retrievers/test_memory_retriever.py | 134 ++++++++++++ test/preview/document_stores/test_memory.py | 193 ++++++++++++++++++ 6 files changed, 509 insertions(+), 3 deletions(-) create mode 100644 haystack/preview/components/retrievers/__init__.py create mode 100644 haystack/preview/components/retrievers/memory.py create mode 100644 test/preview/components/retrievers/__init__.py create mode 100644 test/preview/components/retrievers/test_memory_retriever.py diff --git a/haystack/preview/components/retrievers/__init__.py b/haystack/preview/components/retrievers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/haystack/preview/components/retrievers/memory.py b/haystack/preview/components/retrievers/memory.py new file mode 100644 index 000000000..54d0ce8b5 --- /dev/null +++ b/haystack/preview/components/retrievers/memory.py @@ -0,0 +1,84 @@ +from dataclasses import dataclass +from typing import Dict, List, Any, Optional + +from haystack.preview import component, Document, ComponentInput, ComponentOutput +from haystack.preview.document_stores import MemoryDocumentStore + + +@component +class MemoryRetriever: + """ + A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm. + """ + + @dataclass + class Input(ComponentInput): + """ + Input data for the MemoryRetriever component. + + :param query: The query string for the retriever. + :param filters: A dictionary with filters to narrow down the search space. + :param top_k: The maximum number of documents to return. 
+ :param scale_score: Whether to scale the BM25 scores or not. + :param stores: A dictionary mapping document store names to instances. + """ + + query: str + filters: Dict[str, Any] + top_k: int + scale_score: bool + stores: Dict[str, Any] + + @dataclass + class Output(ComponentOutput): + """ + Output data from the MemoryRetriever component. + + :param documents: The retrieved documents. + """ + + documents: List[Document] + + def __init__( + self, + document_store_name: str, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + scale_score: bool = True, + ): + """ + Create a MemoryRetriever component. + + :param document_store_name: The name of the MemoryDocumentStore to retrieve documents from. + :param filters: A dictionary with filters to narrow down the search space (default is None). + :param top_k: The maximum number of documents to retrieve (default is 10). + :param scale_score: Whether to scale the BM25 score or not (default is True). + + :raises ValueError: If the specified top_k is not > 0. + """ + self.document_store_name = document_store_name + if top_k <= 0: + raise ValueError(f"top_k must be > 0, but got {top_k}") + self.defaults = {"top_k": top_k, "scale_score": scale_score, "filters": filters or {}} + + def run(self, data: Input) -> Output: + """ + Run the MemoryRetriever on the given input data. + + :param data: The input data for the retriever. + :return: The retrieved documents. + + :raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance. 
+ """ + if self.document_store_name not in data.stores: + raise ValueError( + f"MemoryRetriever's document store '{self.document_store_name}' not found " + f"in input stores {list(data.stores.keys())}" + ) + document_store = data.stores[self.document_store_name] + if not isinstance(document_store, MemoryDocumentStore): + raise ValueError("MemoryRetriever can only be used with a MemoryDocumentStore instance.") + docs = document_store.bm25_retrieval( + query=data.query, filters=data.filters, top_k=data.top_k, scale_score=data.scale_score + ) + return MemoryRetriever.Output(documents=docs) diff --git a/haystack/preview/document_stores/memory/document_store.py b/haystack/preview/document_stores/memory/document_store.py index a3eadb9a7..4e6f47e90 100644 --- a/haystack/preview/document_stores/memory/document_store.py +++ b/haystack/preview/document_stores/memory/document_store.py @@ -1,26 +1,49 @@ +import re from typing import Literal, Any, Dict, List, Optional, Iterable import logging +import numpy as np +import rank_bm25 +from tqdm.auto import tqdm + from haystack.preview.dataclasses import Document from haystack.preview.document_stores.memory._filters import match from haystack.preview.document_stores.errors import DuplicateDocumentError, MissingDocumentError - +from haystack.utils.scipy_utils import expit logger = logging.getLogger(__name__) DuplicatePolicy = Literal["skip", "overwrite", "fail"] +# document scores are essentially unbounded and will be scaled to values between 0 and 1 if scale_score is set to +# True (default). Scaling uses the expit function (inverse of the logit function) after applying a SCALING_FACTOR. A +# larger SCALING_FACTOR decreases scaled scores. For example, an input of 10 is scaled to 0.99 with SCALING_FACTOR=2 +# but to 0.78 with SCALING_FACTOR=8 (default). The default was chosen empirically. Increase the default if most +# unscaled scores are larger than expected (>30) and otherwise would incorrectly all be mapped to scores ~1. 
+SCALING_FACTOR = 8 + class MemoryDocumentStore: """ Stores data in-memory. It's ephemeral and cannot be saved to disk. """ - def __init__(self): + def __init__( + self, + bm25_tokenization_regex: str = r"(?u)\b\w\w+\b", + bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25Okapi", + bm25_parameters: Optional[Dict] = None, + ): """ Initializes the store. + + :param bm25_tokenization_regex: The regular expression used to tokenize text for BM25 retrieval. + :param bm25_algorithm: The name of the rank_bm25 class used to score documents. + :param bm25_parameters: Keyword arguments passed to the BM25 class when it is instantiated. + + :raises ValueError: If `bm25_algorithm` is not an attribute of the rank_bm25 module. """ - self.storage = {} + self.storage: Dict[str, Document] = {} + self.tokenizer = re.compile(bm25_tokenization_regex).findall + # Pass a default so an unknown name yields None (not AttributeError) and the ValueError below is reachable. + algorithm_class = getattr(rank_bm25, bm25_algorithm, None) + if algorithm_class is None: + raise ValueError(f"BM25 algorithm '{bm25_algorithm}' not found.") + self.bm25_algorithm = algorithm_class + self.bm25_parameters = bm25_parameters or {} def count_documents(self) -> int: """ @@ -142,3 +165,75 @@ class MemoryDocumentStore: if not doc_id in self.storage.keys(): raise MissingDocumentError(f"ID '{doc_id}' not found, cannot delete it.") del self.storage[doc_id] + + def bm25_retrieval( + self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = True + ) -> List[Document]: + """ + Retrieves documents that are most relevant to the query using BM25 algorithm. + + :param query: The query string. + :param filters: A dictionary with filters to narrow down the search space. + :param top_k: The number of top documents to retrieve. Default is 10. + :param scale_score: Whether to scale the scores of the retrieved documents. Default is True. + :return: A list of the top 'k' documents most relevant to the query. + """ + if not query: + raise ValueError("Query should be a non-empty string") + + # Get all documents that match the user's filters AND are either 'table' or 'text'. + # Raises an exception if the user was trying to include other content types.
+ if filters and "content_type" in filters: + content_types = filters["content_type"] + if isinstance(content_types, str): + content_types = [content_types] + if any(type_ not in ["text", "table"] for type_ in content_types): + raise ValueError( + "MemoryDocumentStore can do BM25 retrieval on no other document type than text or table." + ) + else: + filters = filters or {} + filters = {**filters, "content_type": ["text", "table"]} + all_documents = self.filter_documents(filters=filters) + + # FIXME: remove this guard after resolving https://github.com/deepset-ai/canals/issues/33 + top_k = top_k if top_k is not None else 10 + + # Lowercase all documents + lower_case_documents = [] + for doc in all_documents: + if doc.content_type == "text": + lower_case_documents.append(doc.content.lower()) + elif doc.content_type == "table": + str_content = doc.content.astype(str) + csv_content = str_content.to_csv(index=False) + lower_case_documents.append(csv_content.lower()) + + # Tokenize the entire content of the document store + tokenized_corpus = [ + self.tokenizer(doc) for doc in tqdm(lower_case_documents, unit=" docs", desc="Ranking by BM25...") + ] + if len(tokenized_corpus) == 0: + logger.info("No documents found for BM25 retrieval. 
Returning empty list.") + return [] + + # initialize BM25 + bm25_scorer = self.bm25_algorithm(tokenized_corpus, **self.bm25_parameters) + # tokenize query + tokenized_query = self.tokenizer(query.lower()) + # get scores for the query against the corpus + docs_scores = bm25_scorer.get_scores(tokenized_query) + if scale_score: + docs_scores = [float(expit(np.asarray(score / SCALING_FACTOR))) for score in docs_scores] + # get the last top_k indexes and reverse them + top_docs_positions = np.argsort(docs_scores)[-top_k:][::-1] + + # Create documents with the BM25 score to return them + return_documents = [] + for i in top_docs_positions: + doc = all_documents[i] + doc_fields = doc.to_dict() + doc_fields["score"] = docs_scores[i] + return_document = Document(**doc_fields) + return_documents.append(return_document) + return return_documents diff --git a/test/preview/components/retrievers/__init__.py b/test/preview/components/retrievers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/preview/components/retrievers/test_memory_retriever.py b/test/preview/components/retrievers/test_memory_retriever.py new file mode 100644 index 000000000..33963a6e5 --- /dev/null +++ b/test/preview/components/retrievers/test_memory_retriever.py @@ -0,0 +1,134 @@ +from typing import Dict, Any, List + +import pytest + +from haystack.preview import Pipeline +from haystack.preview.components.retrievers.memory import MemoryRetriever +from haystack.preview.dataclasses import Document +from haystack.preview.document_stores import MemoryDocumentStore + +from test.preview.components.base import BaseTestComponent + + +@pytest.fixture() +def mock_docs(): + return [ + Document.from_dict({"content": "Javascript is a popular programming language"}), + Document.from_dict({"content": "Java is a popular programming language"}), + Document.from_dict({"content": "Python is a popular programming language"}), + Document.from_dict({"content": "Ruby is a popular programming 
language"}), + Document.from_dict({"content": "PHP is a popular programming language"}), + ] + + +class Test_MemoryRetriever(BaseTestComponent): + @pytest.mark.unit + def test_save_load(self, tmp_path): + self.assert_can_be_saved_and_loaded_in_pipeline(MemoryRetriever(document_store_name="memory"), tmp_path) + + @pytest.mark.unit + def test_save_load_with_parameters(self, tmp_path): + self.assert_can_be_saved_and_loaded_in_pipeline( + MemoryRetriever(document_store_name="memory", top_k=5, scale_score=False), tmp_path + ) + + @pytest.mark.unit + def test_init_default(self): + retriever = MemoryRetriever(document_store_name="memory") + assert retriever.document_store_name == "memory" + assert retriever.defaults == {"filters": {}, "top_k": 10, "scale_score": True} + + @pytest.mark.unit + def test_init_with_parameters(self): + retriever = MemoryRetriever(document_store_name="memory-test", top_k=5, scale_score=False) + assert retriever.document_store_name == "memory-test" + assert retriever.defaults == {"filters": {}, "top_k": 5, "scale_score": False} + + @pytest.mark.unit + def test_init_with_invalid_top_k_parameter(self): + with pytest.raises(ValueError, match="top_k must be > 0, but got -2"): + MemoryRetriever(document_store_name="memory-test", top_k=-2, scale_score=False) + + @pytest.mark.unit + def test_valid_run(self, mock_docs): + top_k = 5 + ds = MemoryDocumentStore() + ds.write_documents(mock_docs) + mr = MemoryRetriever(document_store_name="memory", top_k=top_k) + result: MemoryRetriever.Output = mr.run(data=MemoryRetriever.Input(query="PHP", stores={"memory": ds})) + + assert getattr(result, "documents") + assert len(result.documents) == top_k + assert result.documents[0].content == "PHP is a popular programming language" + + @pytest.mark.unit + def test_invalid_run_wrong_store_name(self): + # Test invalid run with wrong store name + ds = MemoryDocumentStore() + mr = MemoryRetriever(document_store_name="memory") + with pytest.raises(ValueError, 
match=r"MemoryRetriever's document store 'memory' not found"): + invalid_input_data = MemoryRetriever.Input( + query="test", top_k=10, scale_score=True, stores={"invalid_store": ds} + ) + mr.run(invalid_input_data) + + @pytest.mark.unit + def test_invalid_run_wrong_store_type(self): + # Test invalid run with wrong store type + ds = MemoryDocumentStore() + mr = MemoryRetriever(document_store_name="memory") + with pytest.raises(ValueError, match=r"MemoryRetriever can only be used with a MemoryDocumentStore instance."): + invalid_input_data = MemoryRetriever.Input( + query="test", top_k=10, scale_score=True, stores={"memory": "not a MemoryDocumentStore"} + ) + mr.run(invalid_input_data) + + @pytest.mark.integration + @pytest.mark.parametrize( + "query, query_result", + [ + ("Javascript", "Javascript is a popular programming language"), + ("Java", "Java is a popular programming language"), + ], + ) + def test_run_with_pipeline(self, mock_docs, query: str, query_result: str): + ds = MemoryDocumentStore() + ds.write_documents(mock_docs) + mr = MemoryRetriever(document_store_name="memory") + + pipeline = Pipeline() + pipeline.add_component("retriever", mr) + pipeline.add_store("memory", ds) + result: Dict[str, Any] = pipeline.run(data={"retriever": MemoryRetriever.Input(query=query)}) + + assert result + assert "retriever" in result + results_docs = result["retriever"].documents + assert results_docs + assert results_docs[0].content == query_result + + @pytest.mark.integration + @pytest.mark.parametrize( + "query, query_result, top_k", + [ + ("Javascript", "Javascript is a popular programming language", 1), + ("Java", "Java is a popular programming language", 2), + ("Ruby", "Ruby is a popular programming language", 3), + ], + ) + def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int): + ds = MemoryDocumentStore() + ds.write_documents(mock_docs) + mr = MemoryRetriever(document_store_name="memory") + + pipeline = Pipeline() + 
pipeline.add_component("retriever", mr) + pipeline.add_store("memory", ds) + result: Dict[str, Any] = pipeline.run(data={"retriever": MemoryRetriever.Input(query=query, top_k=top_k)}) + + assert result + assert "retriever" in result + results_docs = result["retriever"].documents + assert results_docs + assert len(results_docs) == top_k + assert results_docs[0].content == query_result diff --git a/test/preview/document_stores/test_memory.py b/test/preview/document_stores/test_memory.py index 42617485c..1344234bd 100644 --- a/test/preview/document_stores/test_memory.py +++ b/test/preview/document_stores/test_memory.py @@ -1,4 +1,9 @@ +import logging + +import pandas as pd import pytest + +from haystack.preview import Document from haystack.preview.document_stores import MemoryDocumentStore from test.preview.document_stores._base import DocumentStoreBaseTests @@ -12,3 +17,191 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): @pytest.fixture def docstore(self) -> MemoryDocumentStore: return MemoryDocumentStore() + + @pytest.mark.unit + def test_bm25_retrieval(self, docstore): + docstore = MemoryDocumentStore() + # Tests if the bm25_retrieval method returns the correct document based on the input query. + docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")] + docstore.write_documents(docs) + results = docstore.bm25_retrieval(query="What languages?", top_k=1, filters={}) + assert len(results) == 1 + assert results[0].content == "Haystack supports multiple languages" + + @pytest.mark.unit + def test_bm25_retrieval_with_empty_document_store(self, docstore, caplog): + caplog.set_level(logging.INFO) + # Tests if the bm25_retrieval method correctly returns an empty list when there are no documents in the store. + results = docstore.bm25_retrieval(query="How to test this?", top_k=2) + assert len(results) == 0 + assert "No documents found for BM25 retrieval. Returning empty list." 
in caplog.text + + @pytest.mark.unit + def test_bm25_retrieval_empty_query(self, docstore): + # Tests if the bm25_retrieval method raises a ValueError when the query is an empty string. + docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")] + docstore.write_documents(docs) + with pytest.raises(ValueError, match=r"Query should be a non-empty string"): + docstore.bm25_retrieval(query="", top_k=1) + + @pytest.mark.unit + def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_string(self, docstore): + # Tests if the bm25_retrieval method returns only text documents when content_type is filtered as a string. + docs = [ + Document.from_dict({"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}), + Document(content="Haystack supports multiple languages"), + ] + docstore.write_documents(docs) + results = docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": "text"}) + assert len(results) == 1 + assert results[0].content == "Haystack supports multiple languages" + + @pytest.mark.unit + def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_list(self, docstore): + # Tests if the bm25_retrieval method returns only text documents when content_type is filtered as a list. + docs = [ + Document.from_dict({"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}), + Document(content="Haystack supports multiple languages"), + ] + docstore.write_documents(docs) + results = docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["text"]}) + assert len(results) == 1 + assert results[0].content == "Haystack supports multiple languages" + + @pytest.mark.unit + def test_bm25_retrieval_filter_two_allowed_doc_type_as_list(self, docstore): + # Tests if the bm25_retrieval method returns both documents when text and table content types are allowed.
+ docs = [ + Document.from_dict({"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}), + Document(content="Haystack supports multiple languages"), + ] + docstore.write_documents(docs) + results = docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["text", "table"]}) + assert len(results) == 2 + + @pytest.mark.unit + def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_string(self, docstore): + # Tests if the bm25_retrieval method raises a ValueError for an unsupported content_type given as a string. + docs = [ + Document(content=pd.DataFrame({"language": ["Python", "Java"]}), content_type="table"), + Document(content="Haystack supports multiple languages"), + ] + docstore.write_documents(docs) + with pytest.raises( + ValueError, match="MemoryDocumentStore can do BM25 retrieval on no other document type than text or table." + ): + docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": "audio"}) + + @pytest.mark.unit + def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_list(self, docstore): + # Tests if the bm25_retrieval method raises a ValueError for an unsupported content_type given as a list. + docs = [ + Document(content=pd.DataFrame({"language": ["Python", "Java"]}), content_type="table"), + Document(content="Haystack supports multiple languages"), + ] + docstore.write_documents(docs) + with pytest.raises( + ValueError, match="MemoryDocumentStore can do BM25 retrieval on no other document type than text or table." + ): + docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["audio"]}) + + @pytest.mark.unit + def test_bm25_retrieval_filter_two_not_all_allowed_doc_type_as_list(self, docstore): + # Tests if the bm25_retrieval method raises a ValueError when the content_type list includes an unsupported type.
+ docs = [ + Document.from_dict({"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}), + Document(content="Haystack supports multiple languages"), + ] + docstore.write_documents(docs) + with pytest.raises( + ValueError, match="MemoryDocumentStore can do BM25 retrieval on no other document type than text or table." + ): + docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["text", "audio"]}) + + @pytest.mark.unit + def test_bm25_retrieval_with_different_top_k(self, docstore): + # Tests if the bm25_retrieval method correctly changes the number of returned documents + # based on the top_k parameter. + docs = [ + Document(content="Hello world"), + Document(content="Haystack supports multiple languages"), + Document(content="Python is a popular programming language"), + ] + docstore.write_documents(docs) + + # top_k = 2 + results = docstore.bm25_retrieval(query="languages", top_k=2) + assert len(results) == 2 + + # top_k = 3 + results = docstore.bm25_retrieval(query="languages", top_k=3) + assert len(results) == 3 + + # Test two queries and make sure the results are different + @pytest.mark.unit + def test_bm25_retrieval_with_two_queries(self, docstore): + # Tests if the bm25_retrieval method returns different documents for different queries. 
+ docs = [ + Document(content="Javascript is a popular programming language"), + Document(content="Java is a popular programming language"), + Document(content="Python is a popular programming language"), + Document(content="Ruby is a popular programming language"), + Document(content="PHP is a popular programming language"), + ] + docstore.write_documents(docs) + + results = docstore.bm25_retrieval(query="Java", top_k=1) + assert results[0].content == "Java is a popular programming language" + + results = docstore.bm25_retrieval(query="Python", top_k=1) + assert results[0].content == "Python is a popular programming language" + + # Test a query, add a new document and make sure results are appropriately updated + @pytest.mark.unit + def test_bm25_retrieval_with_updated_docs(self, docstore): + # Tests if the bm25_retrieval method correctly updates the retrieved documents when new + # documents are added to the store. + docs = [Document(content="Hello world")] + docstore.write_documents(docs) + + results = docstore.bm25_retrieval(query="Python", top_k=1) + assert len(results) == 1 + + docstore.write_documents([Document(content="Python is a popular programming language")]) + results = docstore.bm25_retrieval(query="Python", top_k=1) + assert len(results) == 1 + assert results[0].content == "Python is a popular programming language" + + docstore.write_documents([Document(content="Java is a popular programming language")]) + results = docstore.bm25_retrieval(query="Python", top_k=1) + assert len(results) == 1 + assert results[0].content == "Python is a popular programming language" + + @pytest.mark.unit + def test_bm25_retrieval_with_scale_score(self, docstore): + docs = [Document(content="Python programming"), Document(content="Java programming")] + docstore.write_documents(docs) + + results1 = docstore.bm25_retrieval(query="Python", top_k=1, scale_score=True) + # Confirm that score is scaled between 0 and 1 + assert 0 <= results1[0].score <= 1 + + # Same query, 
different scale, scores differ when not scaled + results = docstore.bm25_retrieval(query="Python", top_k=1, scale_score=False) + assert results[0].score != results1[0].score + + @pytest.mark.unit + def test_bm25_retrieval_with_table_content(self, docstore): + # Tests if the bm25_retrieval method correctly returns a dataframe when the content_type is table. + table_content = pd.DataFrame({"language": ["Python", "Java"], "use": ["Data Science", "Web Development"]}) + docs = [ + Document(content=table_content, content_type="table"), + Document(content="Gardening", content_type="text"), + Document(content="Bird watching", content_type="text"), + ] + docstore.write_documents(docs) + results = docstore.bm25_retrieval(query="Java", top_k=1) + assert len(results) == 1 + df = results[0].content + assert isinstance(df, pd.DataFrame) + assert df.equals(table_content)