From cd66a80ba25b759a3cf4f4a9b6fca4b403f97b6e Mon Sep 17 00:00:00 2001
From: Guest400123064 <31812718+Guest400123064@users.noreply.github.com>
Date: Fri, 3 May 2024 08:10:15 -0400
Subject: [PATCH] perf: enhanced `InMemoryDocumentStore` BM25 query efficiency
 with incremental indexing (#7549)

* incorporating better bm25 impl without breaking interface
* all three bm25 algos
* 1. setting algo post-init not allowed; 2. remove extra underscore for naming consistency; 3. remove unused import
* 1. rename attribute name for IDF computation 2. organize document statistics as a dataclass instead of tuple to improve readability
* fix score type initialization (int -> float) to pass mypy check
* release note included
* fixing linting issues and mypy
* fixing tests
* removing heapq import and cleaning up logging
* changing indexing order
* adding more tests
* increasing tests
* removing rank_bm25 from pyproject.toml

---------

Co-authored-by: David S. Batista
---
 .../in_memory/document_store.py               | 317 +++++++++++++++---
 pyproject.toml                                |   1 -
 ...incremental-indexing-ebaf43b7163f3a24.yaml |   7 +
 test/document_stores/test_in_memory.py        |  26 +-
 4 files changed, 292 insertions(+), 59 deletions(-)
 create mode 100644 releasenotes/notes/enhance-inmemorydocumentstore-bm25-incremental-indexing-ebaf43b7163f3a24.yaml

diff --git a/haystack/document_stores/in_memory/document_store.py b/haystack/document_stores/in_memory/document_store.py
index 3575b3c93..31a516df7 100644
--- a/haystack/document_stores/in_memory/document_store.py
+++ b/haystack/document_stores/in_memory/document_store.py
@@ -1,9 +1,10 @@
+import math
 import re
-from typing import Any, Dict, Iterable, List, Literal, Optional
+from collections import Counter
+from dataclasses import dataclass
+from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
 
 import numpy as np
-from haystack_bm25 import rank_bm25
-from tqdm.auto import tqdm
 
 from haystack import default_from_dict, default_to_dict, logging
 from haystack.dataclasses import Document
@@ -24,6 +25,19 @@
 BM25_SCALING_FACTOR = 8
 DOT_PRODUCT_SCALING_FACTOR = 100
 
+
+@dataclass
+class BM25DocumentStats:
+    """
+    A dataclass for managing document statistics for BM25 retrieval.
+
+    :param freq_token: A Counter of token frequencies in the document.
+    :param doc_len: Number of tokens in the document.
+    """
+
+    freq_token: Dict[str, int]
+    doc_len: int
+
+
 class InMemoryDocumentStore:
     """
     Stores data in-memory. It's ephemeral and cannot be saved to disk.
@@ -50,15 +64,206 @@ class InMemoryDocumentStore:
         To choose the most appropriate function, look for information about your embedding model.
         """
         self.storage: Dict[str, Document] = {}
-        self._bm25_tokenization_regex = bm25_tokenization_regex
+        self.bm25_tokenization_regex = bm25_tokenization_regex
         self.tokenizer = re.compile(bm25_tokenization_regex).findall
-        algorithm_class = getattr(rank_bm25, bm25_algorithm)
-        if algorithm_class is None:
-            raise ValueError(f"BM25 algorithm '{bm25_algorithm}' not found.")
-        self.bm25_algorithm = algorithm_class
+
+        self.bm25_algorithm = bm25_algorithm
+        self.bm25_algorithm_inst = self._dispatch_bm25()
         self.bm25_parameters = bm25_parameters or {}
         self.embedding_similarity_function = embedding_similarity_function
 
+        # Global BM25 statistics
+        self._avg_doc_len: float = 0.0
+        self._freq_vocab_for_idf: Counter = Counter()
+
+        # Per-document statistics
+        self._bm25_attr: Dict[str, BM25DocumentStats] = {}
+
+    def _dispatch_bm25(self):
+        """
+        Select the correct BM25 algorithm based on user specification.
+
+        :returns:
+            The BM25 algorithm method.
+        """
+        table = {"BM25Okapi": self._score_bm25okapi, "BM25L": self._score_bm25l, "BM25Plus": self._score_bm25plus}
+
+        if self.bm25_algorithm not in table:
+            raise ValueError(f"BM25 algorithm '{self.bm25_algorithm}' is not supported.")
+        return table[self.bm25_algorithm]
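
The dispatch table above replaces the old `getattr(rank_bm25, ...)` lookup, so an invalid algorithm name now fails at construction time. A quick sanity check, as a sketch assuming a build of this branch (`BM25Fancy` is a deliberately invalid name):

```python
from haystack.document_stores.in_memory import InMemoryDocumentStore

# Each supported name maps to a private scoring method via _dispatch_bm25().
store = InMemoryDocumentStore(bm25_algorithm="BM25Plus", bm25_parameters={"k1": 1.2, "b": 0.75})

# Unknown names fail fast at construction time instead of at query time.
try:
    InMemoryDocumentStore(bm25_algorithm="BM25Fancy")
except ValueError as err:
    print(err)  # BM25 algorithm 'BM25Fancy' is not supported.
```
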
+
+    def _tokenize_bm25(self, text: str) -> List[str]:
+        """
+        Tokenize text using the BM25 tokenization regex.
+
+        Here we explicitly create a tokenization method to encapsulate
+        all pre-processing logic used to create BM25 tokens, such as
+        lowercasing. This helps track the exact tokenization process
+        used for BM25 scoring at any given time.
+
+        :param text:
+            The text to tokenize.
+        :returns:
+            A list of tokens.
+        """
+        text = text.lower()
+        return self.tokenizer(text)
+
+    def _score_bm25l(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
+        """
+        Calculate BM25L scores for the given query and filtered documents.
+
+        :param query:
+            The query string.
+        :param documents:
+            The list of documents to score; it should be produced by
+            the filter_documents method and may be an empty list.
+        :returns:
+            A list of tuples, each containing a Document and its BM25L score.
+        """
+        k = self.bm25_parameters.get("k1", 1.5)
+        b = self.bm25_parameters.get("b", 0.75)
+        delta = self.bm25_parameters.get("delta", 0.5)
+
+        def _compute_idf(tokens: List[str]) -> Dict[str, float]:
+            """Per-token IDF computation for all query tokens."""
+            idf = {}
+            n_corpus = len(self._bm25_attr)
+            for tok in tokens:
+                n = self._freq_vocab_for_idf.get(tok, 0)
+                idf[tok] = math.log((n_corpus + 1.0) / (n + 0.5)) * int(n != 0)
+            return idf
+
+        def _compute_tf(token: str, freq: Dict[str, int], doc_len: int) -> float:
+            """Per-token BM25L term-frequency computation."""
+            freq_term = freq.get(token, 0.0)
+            ctd = freq_term / (1 - b + b * doc_len / self._avg_doc_len)
+            return (1.0 + k) * (ctd + delta) / (k + ctd + delta)
+
+        idf = _compute_idf(self._tokenize_bm25(query))
+        bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}
+
+        ret = []
+        for doc in documents:
+            doc_stats = bm25_attr[doc.id]
+            freq = doc_stats.freq_token
+            doc_len = doc_stats.doc_len
+
+            score = 0.0
+            for tok in idf.keys():  # pylint: disable=consider-using-dict-items
+                score += idf[tok] * _compute_tf(tok, freq, doc_len)
+            ret.append((doc, score))
+
+        return ret
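
To make the BM25L arithmetic above concrete, here is a self-contained re-derivation on a toy two-document corpus. It mirrors the formulas but not the class internals; numbers are illustrative only:

```python
import math
from collections import Counter

corpus = [["hello", "world"], ["haystack", "supports", "many", "languages"]]
k1, b, delta = 1.5, 0.75, 0.5
avg_doc_len = sum(len(d) for d in corpus) / len(corpus)  # 3.0

def bm25l(query_tokens, doc_tokens):
    freq = Counter(doc_tokens)
    score = 0.0
    for tok in query_tokens:
        n = sum(tok in d for d in corpus)  # document frequency of the token
        idf = math.log((len(corpus) + 1.0) / (n + 0.5)) * int(n != 0)
        ctd = freq[tok] / (1 - b + b * len(doc_tokens) / avg_doc_len)
        score += idf * (1.0 + k1) * (ctd + delta) / (k1 + ctd + delta)
    return score

print(bm25l(["languages"], corpus[1]))  # ~0.81: the matching document
print(bm25l(["languages"], corpus[0]))  # ~0.43: only the delta-smoothed floor
```

Note the delta term gives every document a small non-zero floor for each query token seen in the corpus, which is exactly BM25L's fix for the over-penalization of long documents.
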
+
+    def _score_bm25okapi(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
+        """
+        Calculate BM25Okapi scores for the given query and filtered documents.
+
+        :param query:
+            The query string.
+        :param documents:
+            The list of documents to score; it should be produced by
+            the filter_documents method and may be an empty list.
+        :returns:
+            A list of tuples, each containing a Document and its BM25Okapi score.
+        """
+        k = self.bm25_parameters.get("k1", 1.5)
+        b = self.bm25_parameters.get("b", 0.75)
+        epsilon = self.bm25_parameters.get("epsilon", 0.25)
+
+        def _compute_idf(tokens: List[str]) -> Dict[str, float]:
+            """Per-token IDF computation for all query tokens."""
+            sum_idf = 0.0
+            neg_idf_tokens = []
+
+            # Although this is a global statistic, we compute it here
+            # to keep the computation self-contained. The complexity is
+            # O(vocab_size), which is acceptable.
+            idf = {}
+            for tok, n in self._freq_vocab_for_idf.items():
+                idf[tok] = math.log((len(self._bm25_attr) - n + 0.5) / (n + 0.5))
+                sum_idf += idf[tok]
+                if idf[tok] < 0:
+                    neg_idf_tokens.append(tok)
+
+            eps = epsilon * sum_idf / len(self._freq_vocab_for_idf)
+            for tok in neg_idf_tokens:
+                idf[tok] = eps
+            return {tok: idf.get(tok, 0.0) for tok in tokens}
+
+        def _compute_tf(token: str, freq: Dict[str, int], doc_len: int) -> float:
+            """Per-token BM25Okapi term-frequency computation."""
+            freq_term = freq.get(token, 0.0)
+            freq_norm = freq_term + k * (1 - b + b * doc_len / self._avg_doc_len)
+            return freq_term * (1.0 + k) / freq_norm
+
+        idf = _compute_idf(self._tokenize_bm25(query))
+        bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}
+
+        ret = []
+        for doc in documents:
+            doc_stats = bm25_attr[doc.id]
+            freq = doc_stats.freq_token
+            doc_len = doc_stats.doc_len
+
+            score = 0.0
+            for tok in idf.keys():
+                score += idf[tok] * _compute_tf(tok, freq, doc_len)
+            ret.append((doc, score))
+
+        return ret
+
+    def _score_bm25plus(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
+        """
+        Calculate BM25+ scores for the given query and filtered documents.
+
+        This implementation follows the description on the BM25 Wikipedia page,
+        which adds 1 (a smoothing factor) to the document frequency when computing the IDF.
+
+        :param query:
+            The query string.
+        :param documents:
+            The list of documents to score; it should be produced by
+            the filter_documents method and may be an empty list.
+        :returns:
+            A list of tuples, each containing a Document and its BM25+ score.
+        """
+        k = self.bm25_parameters.get("k1", 1.5)
+        b = self.bm25_parameters.get("b", 0.75)
+        delta = self.bm25_parameters.get("delta", 1.0)
+
+        def _compute_idf(tokens: List[str]) -> Dict[str, float]:
+            """Per-token IDF computation."""
+            idf = {}
+            n_corpus = len(self._bm25_attr)
+            for tok in tokens:
+                n = self._freq_vocab_for_idf.get(tok, 0)
+                idf[tok] = math.log(1 + (n_corpus - n + 0.5) / (n + 0.5)) * int(n != 0)
+            return idf
+
+        def _compute_tf(token: str, freq: Dict[str, int], doc_len: float) -> float:
+            """Per-token normalized term frequency."""
+            freq_term = freq.get(token, 0.0)
+            freq_damp = k * (1 - b + b * doc_len / self._avg_doc_len)
+            return freq_term * (1.0 + k) / (freq_term + freq_damp) + delta
+
+        idf = _compute_idf(self._tokenize_bm25(query))
+        bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}
+
+        ret = []
+        for doc in documents:
+            doc_stats = bm25_attr[doc.id]
+            freq = doc_stats.freq_token
+            doc_len = doc_stats.doc_len
+
+            score = 0.0
+            for tok in idf.keys():  # pylint: disable=consider-using-dict-items
+                score += idf[tok] * _compute_tf(tok, freq, doc_len)
+            ret.append((doc, score))
+
+        return ret
+
     def to_dict(self) -> Dict[str, Any]:
         """
         Serializes the component to a dictionary.
@@ -68,8 +273,8 @@ class InMemoryDocumentStore:
         """
         return default_to_dict(
             self,
-            bm25_tokenization_regex=self._bm25_tokenization_regex,
-            bm25_algorithm=self.bm25_algorithm.__name__,
+            bm25_tokenization_regex=self.bm25_tokenization_regex,
+            bm25_algorithm=self.bm25_algorithm,
             bm25_parameters=self.bm25_parameters,
             embedding_similarity_function=self.embedding_similarity_function,
         )
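
Because `bm25_algorithm` is now a plain string, serialization round-trips without touching `__name__`. A small sketch, assuming Haystack's usual `init_parameters` layout in `default_to_dict` output:

```python
from haystack.document_stores.in_memory import InMemoryDocumentStore

store = InMemoryDocumentStore(bm25_algorithm="BM25Okapi")
data = store.to_dict()
# The algorithm is stored by name, not as a class reference.
assert data["init_parameters"]["bm25_algorithm"] == "BM25Okapi"

restored = InMemoryDocumentStore.from_dict(data)
assert restored.bm25_algorithm == "BM25Okapi"
```
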
@@ -132,7 +337,36 @@ class InMemoryDocumentStore:
                 logger.warning("ID '{document_id}' already exists", document_id=document.id)
                 written_documents -= 1
                 continue
+
+            # Since the statistics are updated in an incremental manner,
+            # we need to explicitly remove the existing document to revert
+            # the statistics before updating them with the new document.
+            if document.id in self.storage.keys():
+                self.delete_documents([document.id])
+
+            # This processing logic is extracted from the original bm25_retrieval method.
+            # Since we now build the index incrementally at write time instead of at query
+            # time, we have to decide here which content to index.
+            if document.content is not None:
+                if document.dataframe is not None:
+                    logger.warning(
+                        "Document '{document_id}' has both text and dataframe content. "
+                        "Using text content for retrieval and skipping dataframe content.",
+                        document_id=document.id,
+                    )
+                tokens = self._tokenize_bm25(document.content)
+            elif document.dataframe is not None:
+                str_content = document.dataframe.astype(str)
+                csv_content = str_content.to_csv(index=False)
+                tokens = self._tokenize_bm25(csv_content)
+            else:
+                tokens = []
+
             self.storage[document.id] = document
+
+            # Update the running average length *before* registering the new document,
+            # so that len(self._bm25_attr) still reflects the old document count and the
+            # formula stays the exact inverse of the one used in delete_documents.
+            self._avg_doc_len = (len(tokens) + self._avg_doc_len * len(self._bm25_attr)) / (len(self._bm25_attr) + 1)
+            self._bm25_attr[document.id] = BM25DocumentStats(Counter(tokens), len(tokens))
+            self._freq_vocab_for_idf.update(set(tokens))
 
         return written_documents
 
@@ -146,6 +380,17 @@ class InMemoryDocumentStore:
                 continue
             del self.storage[doc_id]
 
+            # Update statistics accordingly
+            doc_stats = self._bm25_attr.pop(doc_id)
+            freq = doc_stats.freq_token
+            doc_len = doc_stats.doc_len
+
+            self._freq_vocab_for_idf.subtract(Counter(freq.keys()))
+            try:
+                self._avg_doc_len = (self._avg_doc_len * (len(self._bm25_attr) + 1) - doc_len) / len(self._bm25_attr)
+            except ZeroDivisionError:
+                self._avg_doc_len = 0.0
+
     def bm25_retrieval(
         self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = False
     ) -> List[Document]:
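
The write/delete bookkeeping above is symmetric: writing folds a document's length into the running average and deleting reverts it. A standalone sketch of the same arithmetic (hypothetical helper names, not the store's API):

```python
from collections import Counter

freq_vocab: Counter = Counter()  # number of documents containing each token
doc_tokens: dict = {}            # doc_id -> token list
avg_doc_len = 0.0

def add(doc_id, tokens):
    global avg_doc_len
    # len(doc_tokens) is still the old document count at this point.
    avg_doc_len = (len(tokens) + avg_doc_len * len(doc_tokens)) / (len(doc_tokens) + 1)
    doc_tokens[doc_id] = tokens
    freq_vocab.update(set(tokens))

def remove(doc_id):
    global avg_doc_len
    tokens = doc_tokens.pop(doc_id)
    freq_vocab.subtract(Counter(set(tokens)))
    try:
        avg_doc_len = (avg_doc_len * (len(doc_tokens) + 1) - len(tokens)) / len(doc_tokens)
    except ZeroDivisionError:
        avg_doc_len = 0.0

add("a", ["hello", "world"])               # avg_doc_len -> 2.0
add("b", ["one", "two", "three", "four"])  # avg_doc_len -> 3.0
remove("b")                                # avg_doc_len -> 2.0 again
```
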
@@ -174,65 +419,33 @@ class InMemoryDocumentStore:
             filters = {"operator": "AND", "conditions": [content_type_filter, filters]}
         else:
             filters = content_type_filter
+
         all_documents = self.filter_documents(filters=filters)
-
-        # Lowercase all documents
-        lower_case_documents = []
-        for doc in all_documents:
-            if doc.content is None and doc.dataframe is None:
-                logger.info(
-                    "Document '{document_id}' has no text or dataframe content. Skipping it.", document_id=doc.id
-                )
-            else:
-                if doc.content is not None:
-                    lower_case_documents.append(doc.content.lower())
-                    if doc.dataframe is not None:
-                        logger.warning(
-                            "Document '{document_id}' has both text and dataframe content. "
-                            "Using text content and skipping dataframe content.",
-                            document_id=doc.id,
-                        )
-                    continue
-                if doc.dataframe is not None:
-                    str_content = doc.dataframe.astype(str)
-                    csv_content = str_content.to_csv(index=False)
-                    lower_case_documents.append(csv_content.lower())
-
-        # Tokenize the entire content of the DocumentStore
-        tokenized_corpus = [
-            self.tokenizer(doc) for doc in tqdm(lower_case_documents, unit=" docs", desc="Ranking by BM25...")
-        ]
-        if len(tokenized_corpus) == 0:
+        if len(all_documents) == 0:
             logger.info("No documents found for BM25 retrieval. Returning empty list.")
             return []
 
-        # initialize BM25
-        bm25_scorer = self.bm25_algorithm(tokenized_corpus, **self.bm25_parameters)
-        # tokenize query
-        tokenized_query = self.tokenizer(query.lower())
-        # get scores for the query against the corpus
-        docs_scores = bm25_scorer.get_scores(tokenized_query)
-        if scale_score:
-            docs_scores = [expit(float(score / BM25_SCALING_FACTOR)) for score in docs_scores]
-        # get the last top_k indexes and reverse them
-        top_docs_positions = np.argsort(docs_scores)[-top_k:][::-1]
+        results = sorted(self.bm25_algorithm_inst(query, all_documents), key=lambda x: x[1], reverse=True)[:top_k]
 
         # BM25Okapi can return meaningful negative values, so they should not be filtered out when scale_score is False.
         # It's the only algorithm supported by rank_bm25 at the time of writing (2024) that can return negative scores.
         # See https://github.com/deepset-ai/haystack/pull/6889 for more context.
-        negatives_are_valid = self.bm25_algorithm is rank_bm25.BM25Okapi and not scale_score
+        negatives_are_valid = self.bm25_algorithm == "BM25Okapi" and not scale_score
 
         # Create documents with the BM25 score to return them
         return_documents = []
-        for i in top_docs_positions:
-            doc = all_documents[i]
-            score = docs_scores[i]
+        for doc, score in results:
+            if scale_score:
+                score = expit(score / BM25_SCALING_FACTOR)
+
             if not negatives_are_valid and score <= 0.0:
                 continue
+
             doc_fields = doc.to_dict()
             doc_fields["score"] = score
             return_document = Document.from_dict(doc_fields)
             return_documents.append(return_document)
+
         return return_documents
 
     def embedding_retrieval(
diff --git a/pyproject.toml b/pyproject.toml
index 50e0f5a6e..e8b65ac53 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,7 +47,6 @@ classifiers = [
 ]
 dependencies = [
   "pandas",
-  "haystack-bm25",
   "tqdm",
   "tenacity",
   "lazy-imports",
diff --git a/releasenotes/notes/enhance-inmemorydocumentstore-bm25-incremental-indexing-ebaf43b7163f3a24.yaml b/releasenotes/notes/enhance-inmemorydocumentstore-bm25-incremental-indexing-ebaf43b7163f3a24.yaml
new file mode 100644
index 000000000..1c8d3f380
--- /dev/null
+++ b/releasenotes/notes/enhance-inmemorydocumentstore-bm25-incremental-indexing-ebaf43b7163f3a24.yaml
@@ -0,0 +1,7 @@
+---
+enhancements:
+  - |
+    Re-implement `InMemoryDocumentStore` BM25 search with incremental indexing, which avoids re-creating
+    the entire inverted index for every new query. This change also removes the dependency on
+    `haystack_bm25`. Please refer to [PR #7549](https://github.com/deepset-ai/haystack/pull/7549)
+    for the full context.
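
The practical effect of the release note's claim can be eyeballed with a rough timing loop. This is only a sketch: absolute numbers are machine-dependent and the corpus is synthetic:

```python
import time

from haystack import Document
from haystack.document_stores.in_memory import InMemoryDocumentStore

store = InMemoryDocumentStore()
store.write_documents([Document(content=f"document number {i}") for i in range(10_000)])

# Statistics were built incrementally at write time, so each query only
# scores the filtered candidates instead of re-tokenizing the whole corpus.
start = time.perf_counter()
for _ in range(100):
    store.bm25_retrieval(query="document", top_k=5)
print(f"100 queries took {time.perf_counter() - start:.3f}s")
```
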
diff --git a/test/document_stores/test_in_memory.py b/test/document_stores/test_in_memory.py
index 3b31c13db..1d633b98f 100644
--- a/test/document_stores/test_in_memory.py
+++ b/test/document_stores/test_in_memory.py
@@ -3,7 +3,6 @@ from unittest.mock import patch
 
 import pandas as pd
 import pytest
-from haystack_bm25 import rank_bm25
 
 from haystack import Document
 from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
@@ -64,9 +63,13 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):  # pylint: disable=R0904
         store = InMemoryDocumentStore.from_dict(data)
         mock_regex.compile.assert_called_with("custom_regex")
         assert store.tokenizer
-        assert store.bm25_algorithm.__name__ == "BM25Plus"
+        assert store.bm25_algorithm == "BM25Plus"
         assert store.bm25_parameters == {"key": "value"}
 
+    def test_invalid_bm25_algorithm(self):
+        with pytest.raises(ValueError, match="BM25 algorithm 'invalid' is not supported"):
+            InMemoryDocumentStore(bm25_algorithm="invalid")
+
     def test_write_documents(self, document_store):
         docs = [Document(id="1")]
         assert document_store.write_documents(docs) == 1
@@ -113,7 +116,18 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):  # pylint: disable=R0904
         results = document_store.bm25_retrieval(query="languages", top_k=3)
         assert len(results) == 3
 
-    # Test two queries and make sure the results are different
+    def test_bm25_plus_retrieval(self):
+        doc_store = InMemoryDocumentStore(bm25_algorithm="BM25Plus")
+        docs = [
+            Document(content="Hello world"),
+            Document(content="Haystack supports multiple languages"),
+            Document(content="Python is a popular programming language"),
+        ]
+        doc_store.write_documents(docs)
+
+        results = doc_store.bm25_retrieval(query="language", top_k=1)
+        assert len(results) == 1
+        assert results[0].content == "Python is a popular programming language"
+
     def test_bm25_retrieval_with_two_queries(self, document_store: InMemoryDocumentStore):
         # Tests if the bm25_retrieval method returns different documents for different queries.
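
One gap a reviewer might flag: none of the new tests assert that deletion reverts the incremental statistics. A possible companion test (not part of this patch; it peeks at private attributes and assumes the default tokenization regex):

```python
def test_bm25_stats_reverted_on_delete():
    doc_store = InMemoryDocumentStore()
    doc_store.write_documents(
        [Document(id="1", content="Python programming"), Document(id="2", content="Java programming")]
    )
    assert doc_store._avg_doc_len == 2.0

    doc_store.delete_documents(["2"])
    assert doc_store._avg_doc_len == 2.0  # only doc "1" (two tokens) remains
    assert doc_store._freq_vocab_for_idf["java"] == 0
    assert doc_store._freq_vocab_for_idf["programming"] == 1
```
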
@@ -166,7 +180,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):  # pylint: disable=R0904
         results = document_store.bm25_retrieval(query="Python", top_k=1, scale_score=False)
         assert results[0].score != results1[0].score
 
-    def test_bm25_retrieval_with_non_scaled_BM25Okapi(self, document_store: InMemoryDocumentStore):
+    def test_bm25_retrieval_with_non_scaled_BM25Okapi(self):
         # Highly repetitive documents make BM25Okapi return negative scores, which should not be filtered if the
         # scores are not scaled
         docs = [
@@ -188,9 +202,9 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):  # pylint: disable=R0904
                 to try the new features as soon as they are merged."""
             ),
         ]
+        document_store = InMemoryDocumentStore(bm25_algorithm="BM25Okapi")
         document_store.write_documents(docs)
 
-        document_store.bm25_algorithm = rank_bm25.BM25Okapi
         results1 = document_store.bm25_retrieval(query="Haystack installation", top_k=10, scale_score=False)
         assert len(results1) == 3
         assert all(res.score < 0.0 for res in results1)
@@ -215,11 +229,11 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):  # pylint: disable=R0904
         table_content = pd.DataFrame({"language": ["Python", "Java"], "use": ["Data Science", "Web Development"]})
         document = Document(content="Gardening", dataframe=table_content)
         docs = [
-            document,
             Document(content="Python"),
             Document(content="Bird Watching"),
             Document(content="Gardening"),
             Document(content="Java"),
+            document,
         ]
         document_store.write_documents(docs)
         results = document_store.bm25_retrieval(query="Gardening", top_k=2)
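
The reordered test above exercises the indexing rule introduced in `write_documents`: when a document carries both text and a dataframe, only the text is indexed. An end-to-end sketch of that rule, assuming this branch is installed (the `dataframe` field still existed in Haystack at the time of this PR):

```python
import pandas as pd

from haystack import Document
from haystack.document_stores.in_memory import InMemoryDocumentStore

store = InMemoryDocumentStore()
table = pd.DataFrame({"language": ["Python", "Java"]})

store.write_documents(
    [
        Document(content="Gardening", dataframe=table),  # text wins; a warning is logged
        Document(dataframe=table),                       # dataframe-only: indexed as CSV text
        Document(content="Bird watching"),
    ]
)

# The dataframe-only document ranks first for "Python"; the document that also
# has text content was indexed by its text ("Gardening") instead.
results = store.bm25_retrieval(query="Python", top_k=2)
print([doc.id for doc in results])
```
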