perf: enhanced InMemoryDocumentStore BM25 query efficiency with incremental indexing (#7549)

* incorporating better bm25 impl without breaking interface

* all three bm25 algos

* 1. setting algo post-init not allowed; 2. remove extra underscore for naming consistency; 3. remove unused import

* 1. rename attribute name for IDF computation 2. organize document statistics as a dataclass instead of tuple to improve readability

* fix score type initialization (int -> float) to pass mypy check

* release note included

* fixing linting issues and mypy

* fixing tests

* removing heapq import and cleaning up logging

* changing indexing order

* adding more tests

* increasing tests

* removing rank_bm25 from pyproject.toml

---------

Co-authored-by: David S. Batista <dsbatista@gmail.com>
Guest400123064 2024-05-03 08:10:15 -04:00 committed by GitHub
parent 48c7c6ad26
commit cd66a80ba2
4 changed files with 292 additions and 59 deletions

View File: haystack/document_stores/in_memory/document_store.py

@@ -1,9 +1,10 @@
+import math
import re
-from typing import Any, Dict, Iterable, List, Literal, Optional
+from collections import Counter
+from dataclasses import dataclass
+from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
import numpy as np
-from haystack_bm25 import rank_bm25
-from tqdm.auto import tqdm
from haystack import default_from_dict, default_to_dict, logging
from haystack.dataclasses import Document
@@ -24,6 +25,19 @@ BM25_SCALING_FACTOR = 8
DOT_PRODUCT_SCALING_FACTOR = 100
@dataclass
class BM25DocumentStats:
"""
A dataclass for managing document statistics for BM25 retrieval.
:param freq_token: A Counter of token frequencies in the document.
:param doc_len: Number of tokens in the document.
"""
freq_token: Dict[str, int]
doc_len: int
class InMemoryDocumentStore:
"""
Stores data in-memory. It's ephemeral and cannot be saved to disk.
@@ -50,15 +64,206 @@ class InMemoryDocumentStore:
To choose the most appropriate function, look for information about your embedding model.
"""
self.storage: Dict[str, Document] = {}
-        self._bm25_tokenization_regex = bm25_tokenization_regex
+        self.bm25_tokenization_regex = bm25_tokenization_regex
        self.tokenizer = re.compile(bm25_tokenization_regex).findall
-        algorithm_class = getattr(rank_bm25, bm25_algorithm)
-        if algorithm_class is None:
-            raise ValueError(f"BM25 algorithm '{bm25_algorithm}' not found.")
-        self.bm25_algorithm = algorithm_class
+        self.bm25_algorithm = bm25_algorithm
+        self.bm25_algorithm_inst = self._dispatch_bm25()
self.bm25_parameters = bm25_parameters or {}
self.embedding_similarity_function = embedding_similarity_function
# Global BM25 statistics
self._avg_doc_len: float = 0.0
self._freq_vocab_for_idf: Counter = Counter()
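        # Counts the number of documents containing each token (document frequency),
        # since each document contributes its token *set* exactly once; used for IDF.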
# Per-document statistics
self._bm25_attr: Dict[str, BM25DocumentStats] = {}
def _dispatch_bm25(self):
"""
Select the correct BM25 algorithm based on user specification.
:returns:
The BM25 algorithm method.
"""
table = {"BM25Okapi": self._score_bm25okapi, "BM25L": self._score_bm25l, "BM25Plus": self._score_bm25plus}
if self.bm25_algorithm not in table:
raise ValueError(f"BM25 algorithm '{self.bm25_algorithm}' is not supported.")
return table[self.bm25_algorithm]
def _tokenize_bm25(self, text: str) -> List[str]:
"""
Tokenize text using the BM25 tokenization regex.
Here we explicitly create a tokenization method to encapsulate
all pre-processing logic used to create BM25 tokens, such as
lowercasing. This helps track the exact tokenization process
used for BM25 scoring at any given time.
:param text:
The text to tokenize.
:returns:
A list of tokens.
"""
text = text.lower()
return self.tokenizer(text)
def _score_bm25l(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
"""
Calculate BM25L scores for the given query and filtered documents.
:param query:
The query string.
:param documents:
            The list of documents to score; should be produced by
            the `filter_documents` method and may be an empty list.
:returns:
A list of tuples, each containing a Document and its BM25L score.
"""
k = self.bm25_parameters.get("k1", 1.5)
b = self.bm25_parameters.get("b", 0.75)
delta = self.bm25_parameters.get("delta", 0.5)
def _compute_idf(tokens: List[str]) -> Dict[str, float]:
"""Per-token IDF computation for all tokens."""
idf = {}
n_corpus = len(self._bm25_attr)
for tok in tokens:
n = self._freq_vocab_for_idf.get(tok, 0)
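                # BM25L IDF: log((N + 1) / (df + 0.5)); tokens unseen in the corpus get zero IDF.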
idf[tok] = math.log((n_corpus + 1.0) / (n + 0.5)) * int(n != 0)
return idf
def _compute_tf(token: str, freq: Dict[str, int], doc_len: int) -> float:
"""Per-token BM25L computation."""
freq_term = freq.get(token, 0.0)
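            # c'_td: term frequency normalized by document length, per the BM25L formulation.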
ctd = freq_term / (1 - b + b * doc_len / self._avg_doc_len)
return (1.0 + k) * (ctd + delta) / (k + ctd + delta)
idf = _compute_idf(self._tokenize_bm25(query))
bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}
ret = []
for doc in documents:
doc_stats = bm25_attr[doc.id]
freq = doc_stats.freq_token
doc_len = doc_stats.doc_len
score = 0.0
for tok in idf.keys(): # pylint: disable=consider-using-dict-items
score += idf[tok] * _compute_tf(tok, freq, doc_len)
ret.append((doc, score))
return ret
def _score_bm25okapi(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
"""
Calculate BM25Okapi scores for the given query and filtered documents.
:param query:
The query string.
:param documents:
            The list of documents to score; should be produced by
            the `filter_documents` method and may be an empty list.
:returns:
            A list of tuples, each containing a Document and its BM25Okapi score.
"""
k = self.bm25_parameters.get("k1", 1.5)
b = self.bm25_parameters.get("b", 0.75)
epsilon = self.bm25_parameters.get("epsilon", 0.25)
def _compute_idf(tokens: List[str]) -> Dict[str, float]:
"""Per-token IDF computation for all tokens."""
sum_idf = 0.0
neg_idf_tokens = []
# Although this is a global statistic, we compute it here
# to make the computation more self-contained. And the
# complexity is O(vocab_size), which is acceptable.
idf = {}
for tok, n in self._freq_vocab_for_idf.items():
idf[tok] = math.log((len(self._bm25_attr) - n + 0.5) / (n + 0.5))
sum_idf += idf[tok]
if idf[tok] < 0:
neg_idf_tokens.append(tok)
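            # Clamp negative IDFs to epsilon times the average IDF over the vocabulary,
            # the standard BM25Okapi correction for tokens present in most documents.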
eps = epsilon * sum_idf / len(self._freq_vocab_for_idf)
for tok in neg_idf_tokens:
idf[tok] = eps
return {tok: idf.get(tok, 0.0) for tok in tokens}
def _compute_tf(token: str, freq: Dict[str, int], doc_len: int) -> float:
"""Per-token BM25L computation."""
freq_term = freq.get(token, 0.0)
freq_norm = freq_term + k * (1 - b + b * doc_len / self._avg_doc_len)
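            # Okapi TF saturation: tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avgdl)).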
return freq_term * (1.0 + k) / freq_norm
idf = _compute_idf(self._tokenize_bm25(query))
bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}
ret = []
for doc in documents:
doc_stats = bm25_attr[doc.id]
freq = doc_stats.freq_token
doc_len = doc_stats.doc_len
score = 0.0
for tok in idf.keys():
score += idf[tok] * _compute_tf(tok, freq, doc_len)
ret.append((doc, score))
return ret
def _score_bm25plus(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
"""
Calculate BM25+ scores for the given query and filtered documents.
        This implementation follows the formulation on the BM25 Wikipedia page,
        which adds 1 (a smoothing factor) to the document frequency when computing the IDF.
:param query:
The query string.
:param documents:
            The list of documents to score; should be produced by
            the `filter_documents` method and may be an empty list.
:returns:
A list of tuples, each containing a Document and its BM25+ score.
"""
k = self.bm25_parameters.get("k1", 1.5)
b = self.bm25_parameters.get("b", 0.75)
delta = self.bm25_parameters.get("delta", 1.0)
def _compute_idf(tokens: List[str]) -> Dict[str, float]:
"""Per-token IDF computation."""
idf = {}
n_corpus = len(self._bm25_attr)
for tok in tokens:
n = self._freq_vocab_for_idf.get(tok, 0)
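                # BM25+ IDF with +1 smoothing: log(1 + (N - df + 0.5) / (df + 0.5)); zero for unseen tokens.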
idf[tok] = math.log(1 + (n_corpus - n + 0.5) / (n + 0.5)) * int(n != 0)
return idf
def _compute_tf(token: str, freq: Dict[str, int], doc_len: float) -> float:
"""Per-token normalized term frequency."""
freq_term = freq.get(token, 0.0)
freq_damp = k * (1 - b + b * doc_len / self._avg_doc_len)
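            # The trailing + delta lower-bounds the contribution of any matching term (the BM25+ fix).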
return freq_term * (1.0 + k) / (freq_term + freq_damp) + delta
idf = _compute_idf(self._tokenize_bm25(query))
bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}
ret = []
for doc in documents:
doc_stats = bm25_attr[doc.id]
freq = doc_stats.freq_token
doc_len = doc_stats.doc_len
score = 0.0
for tok in idf.keys(): # pylint: disable=consider-using-dict-items
score += idf[tok] * _compute_tf(tok, freq, doc_len)
ret.append((doc, score))
return ret
def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
@@ -68,8 +273,8 @@ class InMemoryDocumentStore:
"""
return default_to_dict(
self,
-            bm25_tokenization_regex=self._bm25_tokenization_regex,
-            bm25_algorithm=self.bm25_algorithm.__name__,
+            bm25_tokenization_regex=self.bm25_tokenization_regex,
+            bm25_algorithm=self.bm25_algorithm,
bm25_parameters=self.bm25_parameters,
embedding_similarity_function=self.embedding_similarity_function,
)
@@ -132,7 +337,36 @@ class InMemoryDocumentStore:
logger.warning("ID '{document_id}' already exists", document_id=document.id)
written_documents -= 1
continue
# Since the statistics are updated in an incremental manner,
# we need to explicitly remove the existing document to revert
# the statistics before updating them with the new document.
if document.id in self.storage.keys():
self.delete_documents([document.id])
            # This processing logic is extracted from the original bm25_retrieval method.
            # Since the index is now built incrementally at write time rather than at
            # query time, we must decide here which content to use for indexing.
if document.content is not None:
if document.dataframe is not None:
logger.warning(
"Document '{document_id}' has both text and dataframe content. "
"Using text content for retrieval and skipping dataframe content.",
document_id=document.id,
)
tokens = self._tokenize_bm25(document.content)
elif document.dataframe is not None:
str_content = document.dataframe.astype(str)
csv_content = str_content.to_csv(index=False)
tokens = self._tokenize_bm25(csv_content)
else:
tokens = []
self.storage[document.id] = document
            # Update the running average document length before registering the new
            # document's statistics, so len(self._bm25_attr) is the pre-insertion count.
            self._avg_doc_len = (len(tokens) + self._avg_doc_len * len(self._bm25_attr)) / (len(self._bm25_attr) + 1)
            self._bm25_attr[document.id] = BM25DocumentStats(Counter(tokens), len(tokens))
            self._freq_vocab_for_idf.update(set(tokens))
return written_documents
def delete_documents(self, document_ids: List[str]) -> None:
@@ -146,6 +380,17 @@
continue
del self.storage[doc_id]
# Update statistics accordingly
doc_stats = self._bm25_attr.pop(doc_id)
freq = doc_stats.freq_token
doc_len = doc_stats.doc_len
self._freq_vocab_for_idf.subtract(Counter(freq.keys()))
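            # Shrink the running average using the pre-deletion count:
            # new_avg = (old_avg * old_count - doc_len) / new_count.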
try:
self._avg_doc_len = (self._avg_doc_len * (len(self._bm25_attr) + 1) - doc_len) / len(self._bm25_attr)
except ZeroDivisionError:
                self._avg_doc_len = 0.0
def bm25_retrieval(
self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = False
) -> List[Document]:
@@ -174,65 +419,33 @@
filters = {"operator": "AND", "conditions": [content_type_filter, filters]}
else:
filters = content_type_filter
all_documents = self.filter_documents(filters=filters)
-        # Lowercase all documents
-        lower_case_documents = []
-        for doc in all_documents:
-            if doc.content is None and doc.dataframe is None:
-                logger.info(
-                    "Document '{document_id}' has no text or dataframe content. Skipping it.", document_id=doc.id
-                )
-            else:
-                if doc.content is not None:
-                    lower_case_documents.append(doc.content.lower())
-                    if doc.dataframe is not None:
-                        logger.warning(
-                            "Document '{document_id}' has both text and dataframe content. "
-                            "Using text content and skipping dataframe content.",
-                            document_id=doc.id,
-                        )
-                    continue
-                if doc.dataframe is not None:
-                    str_content = doc.dataframe.astype(str)
-                    csv_content = str_content.to_csv(index=False)
-                    lower_case_documents.append(csv_content.lower())
-        # Tokenize the entire content of the DocumentStore
-        tokenized_corpus = [
-            self.tokenizer(doc) for doc in tqdm(lower_case_documents, unit=" docs", desc="Ranking by BM25...")
-        ]
-        if len(tokenized_corpus) == 0:
+        if len(all_documents) == 0:
logger.info("No documents found for BM25 retrieval. Returning empty list.")
return []
-        # initialize BM25
-        bm25_scorer = self.bm25_algorithm(tokenized_corpus, **self.bm25_parameters)
-        # tokenize query
-        tokenized_query = self.tokenizer(query.lower())
-        # get scores for the query against the corpus
-        docs_scores = bm25_scorer.get_scores(tokenized_query)
-        if scale_score:
-            docs_scores = [expit(float(score / BM25_SCALING_FACTOR)) for score in docs_scores]
-        # get the last top_k indexes and reverse them
-        top_docs_positions = np.argsort(docs_scores)[-top_k:][::-1]
+        results = sorted(self.bm25_algorithm_inst(query, all_documents), key=lambda x: x[1], reverse=True)[:top_k]
# BM25Okapi can return meaningful negative values, so they should not be filtered out when scale_score is False.
# It's the only algorithm supported by rank_bm25 at the time of writing (2024) that can return negative scores.
# see https://github.com/deepset-ai/haystack/pull/6889 for more context.
-        negatives_are_valid = self.bm25_algorithm is rank_bm25.BM25Okapi and not scale_score
+        negatives_are_valid = self.bm25_algorithm == "BM25Okapi" and not scale_score
# Create documents with the BM25 score to return them
return_documents = []
-        for i in top_docs_positions:
-            doc = all_documents[i]
-            score = docs_scores[i]
+        for doc, score in results:
+            if scale_score:
+                score = expit(score / BM25_SCALING_FACTOR)
if not negatives_are_valid and score <= 0.0:
continue
doc_fields = doc.to_dict()
doc_fields["score"] = score
return_document = Document.from_dict(doc_fields)
return_documents.append(return_document)
return return_documents
def embedding_retrieval(

View File: pyproject.toml

@@ -47,7 +47,6 @@ classifiers = [
]
dependencies = [
"pandas",
"haystack-bm25",
"tqdm",
"tenacity",
"lazy-imports",

View File

@@ -0,0 +1,7 @@
---
enhancements:
- |
    Re-implement `InMemoryDocumentStore` BM25 search with incremental indexing, avoiding rebuilding
    the entire inverted index for every new query. This change also removes the dependency on
    `haystack_bm25`. Please refer to [PR #7549](https://github.com/deepset-ai/haystack/pull/7549)
    for the full context.
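For reference, a minimal usage sketch of the retained public interface (assuming the Haystack 2.x import path; the parameter values here are illustrative, not recommendations from this diff):

    from haystack import Document
    from haystack.document_stores.in_memory import InMemoryDocumentStore

    # The algorithm is now selected by name and validated at construction time;
    # BM25 statistics are updated incrementally on write_documents/delete_documents.
    store = InMemoryDocumentStore(bm25_algorithm="BM25Plus", bm25_parameters={"k1": 1.2, "b": 0.6})
    store.write_documents([Document(content="Python is a popular programming language")])
    results = store.bm25_retrieval(query="language", top_k=1)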

View File: test/document_stores/test_in_memory.py

@@ -3,7 +3,6 @@ from unittest.mock import patch
import pandas as pd
import pytest
-from haystack_bm25 import rank_bm25
from haystack import Document
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
@@ -64,9 +63,13 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
store = InMemoryDocumentStore.from_dict(data)
mock_regex.compile.assert_called_with("custom_regex")
assert store.tokenizer
-        assert store.bm25_algorithm.__name__ == "BM25Plus"
+        assert store.bm25_algorithm == "BM25Plus"
assert store.bm25_parameters == {"key": "value"}
def test_invalid_bm25_algorithm(self):
with pytest.raises(ValueError, match="BM25 algorithm 'invalid' is not supported"):
InMemoryDocumentStore(bm25_algorithm="invalid")
def test_write_documents(self, document_store):
docs = [Document(id="1")]
assert document_store.write_documents(docs) == 1
@@ -113,7 +116,18 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
results = document_store.bm25_retrieval(query="languages", top_k=3)
assert len(results) == 3
# Test two queries and make sure the results are different
def test_bm25_plus_retrieval(self):
doc_store = InMemoryDocumentStore(bm25_algorithm="BM25Plus")
docs = [
Document(content="Hello world"),
Document(content="Haystack supports multiple languages"),
Document(content="Python is a popular programming language"),
]
doc_store.write_documents(docs)
results = doc_store.bm25_retrieval(query="language", top_k=1)
assert len(results) == 1
assert results[0].content == "Python is a popular programming language"
def test_bm25_retrieval_with_two_queries(self, document_store: InMemoryDocumentStore):
# Tests if the bm25_retrieval method returns different documents for different queries.
@@ -166,7 +180,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
results = document_store.bm25_retrieval(query="Python", top_k=1, scale_score=False)
assert results[0].score != results1[0].score
-    def test_bm25_retrieval_with_non_scaled_BM25Okapi(self, document_store: InMemoryDocumentStore):
+    def test_bm25_retrieval_with_non_scaled_BM25Okapi(self):
# Highly repetitive documents make BM25Okapi return negative scores, which should not be filtered if the
# scores are not scaled
docs = [
@@ -188,9 +202,9 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
to try the new features as soon as they are merged."""
),
]
+        document_store = InMemoryDocumentStore(bm25_algorithm="BM25Okapi")
        document_store.write_documents(docs)
-        document_store.bm25_algorithm = rank_bm25.BM25Okapi
results1 = document_store.bm25_retrieval(query="Haystack installation", top_k=10, scale_score=False)
assert len(results1) == 3
assert all(res.score < 0.0 for res in results1)
@@ -215,11 +229,11 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
table_content = pd.DataFrame({"language": ["Python", "Java"], "use": ["Data Science", "Web Development"]})
document = Document(content="Gardening", dataframe=table_content)
docs = [
-            document,
Document(content="Python"),
Document(content="Bird Watching"),
Document(content="Gardening"),
Document(content="Java"),
+            document,
]
document_store.write_documents(docs)
results = document_store.bm25_retrieval(query="Gardening", top_k=2)