perf: enhanced InMemoryDocumentStore BM25 query efficiency with incremental indexing (#7549)
* incorporating better bm25 impl without breaking interface
* all three bm25 algos
* 1. setting algo post-init not allowed; 2. remove extra underscore for naming consistency; 3. remove unused import
* 1. rename attribute name for IDF computation 2. organize document statistics as a dataclass instead of tuple to improve readability
* fix score type initialization (int -> float) to pass mypy check
* release note included
* fixing linting issues and mypy
* fixing tests
* removing heapq import and cleaning up logging
* changing indexing order
* adding more tests
* increasing tests
* removing rank_bm25 from pyproject.toml

---------

Co-authored-by: David S. Batista <dsbatista@gmail.com>
This commit is contained in:
parent 48c7c6ad26
commit cd66a80ba2
@@ -1,9 +1,10 @@
+import math
 import re
-from typing import Any, Dict, Iterable, List, Literal, Optional
+from collections import Counter
+from dataclasses import dataclass
+from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
 
 import numpy as np
-from haystack_bm25 import rank_bm25
 from tqdm.auto import tqdm
 
 from haystack import default_from_dict, default_to_dict, logging
 from haystack.dataclasses import Document
@@ -24,6 +25,19 @@ BM25_SCALING_FACTOR = 8
 DOT_PRODUCT_SCALING_FACTOR = 100
 
 
+@dataclass
+class BM25DocumentStats:
+    """
+    A dataclass for managing document statistics for BM25 retrieval.
+
+    :param freq_token: A Counter of token frequencies in the document.
+    :param doc_len: Number of tokens in the document.
+    """
+
+    freq_token: Dict[str, int]
+    doc_len: int
+
+
 class InMemoryDocumentStore:
     """
     Stores data in-memory. It's ephemeral and cannot be saved to disk.
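As a quick sketch, this is how one such record is populated for a freshly tokenized document (toy tokens; `BM25DocumentStats` as defined above):

```python
from collections import Counter

# Toy example: per-document statistics for one tokenized document.
tokens = ["python", "is", "a", "popular", "programming", "language"]
stats = BM25DocumentStats(freq_token=Counter(tokens), doc_len=len(tokens))

assert stats.freq_token["python"] == 1
assert stats.doc_len == 6
```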
@@ -50,15 +64,206 @@ class InMemoryDocumentStore:
         To choose the most appropriate function, look for information about your embedding model.
         """
         self.storage: Dict[str, Document] = {}
-        self._bm25_tokenization_regex = bm25_tokenization_regex
+        self.bm25_tokenization_regex = bm25_tokenization_regex
         self.tokenizer = re.compile(bm25_tokenization_regex).findall
-        algorithm_class = getattr(rank_bm25, bm25_algorithm)
-        if algorithm_class is None:
-            raise ValueError(f"BM25 algorithm '{bm25_algorithm}' not found.")
-        self.bm25_algorithm = algorithm_class
+
+        self.bm25_algorithm = bm25_algorithm
+        self.bm25_algorithm_inst = self._dispatch_bm25()
         self.bm25_parameters = bm25_parameters or {}
         self.embedding_similarity_function = embedding_similarity_function
+
+        # Global BM25 statistics
+        self._avg_doc_len: float = 0.0
+        self._freq_vocab_for_idf: Counter = Counter()
+
+        # Per-document statistics
+        self._bm25_attr: Dict[str, BM25DocumentStats] = {}
+
+    def _dispatch_bm25(self):
+        """
+        Select the correct BM25 algorithm based on user specification.
+
+        :returns:
+            The BM25 algorithm method.
+        """
+        table = {"BM25Okapi": self._score_bm25okapi, "BM25L": self._score_bm25l, "BM25Plus": self._score_bm25plus}
+
+        if self.bm25_algorithm not in table:
+            raise ValueError(f"BM25 algorithm '{self.bm25_algorithm}' is not supported.")
+        return table[self.bm25_algorithm]
+
+    def _tokenize_bm25(self, text: str) -> List[str]:
+        """
+        Tokenize text using the BM25 tokenization regex.
+
+        Here we explicitly create a tokenization method to encapsulate
+        all pre-processing logic used to create BM25 tokens, such as
+        lowercasing. This helps track the exact tokenization process
+        used for BM25 scoring at any given time.
+
+        :param text:
+            The text to tokenize.
+        :returns:
+            A list of tokens.
+        """
+        text = text.lower()
+        return self.tokenizer(text)
+
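For orientation, a standalone sketch of what this tokenization yields, assuming the store's default `bm25_tokenization_regex` of `(?u)\b\w\w+\b` (runs of two or more word characters):

```python
import re

# Sketch of the store's tokenization with the (assumed) default regex.
tokenizer = re.compile(r"(?u)\b\w\w+\b").findall

tokens = tokenizer("Haystack supports BM25-based retrieval!".lower())
print(tokens)  # ['haystack', 'supports', 'bm25', 'based', 'retrieval']
```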
+    def _score_bm25l(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
+        """
+        Calculate BM25L scores for the given query and filtered documents.
+
+        :param query:
+            The query string.
+        :param documents:
+            The list of documents to score, should be produced by
+            the filter_documents method; may be an empty list.
+        :returns:
+            A list of tuples, each containing a Document and its BM25L score.
+        """
+        k = self.bm25_parameters.get("k1", 1.5)
+        b = self.bm25_parameters.get("b", 0.75)
+        delta = self.bm25_parameters.get("delta", 0.5)
+
+        def _compute_idf(tokens: List[str]) -> Dict[str, float]:
+            """Per-token IDF computation for all tokens."""
+            idf = {}
+            n_corpus = len(self._bm25_attr)
+            for tok in tokens:
+                n = self._freq_vocab_for_idf.get(tok, 0)
+                idf[tok] = math.log((n_corpus + 1.0) / (n + 0.5)) * int(n != 0)
+            return idf
+
+        def _compute_tf(token: str, freq: Dict[str, int], doc_len: int) -> float:
+            """Per-token BM25L computation."""
+            freq_term = freq.get(token, 0.0)
+            ctd = freq_term / (1 - b + b * doc_len / self._avg_doc_len)
+            return (1.0 + k) * (ctd + delta) / (k + ctd + delta)
+
+        idf = _compute_idf(self._tokenize_bm25(query))
+        bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}
+
+        ret = []
+        for doc in documents:
+            doc_stats = bm25_attr[doc.id]
+            freq = doc_stats.freq_token
+            doc_len = doc_stats.doc_len
+
+            score = 0.0
+            for tok in idf.keys():  # pylint: disable=consider-using-dict-items
+                score += idf[tok] * _compute_tf(tok, freq, doc_len)
+            ret.append((doc, score))
+
+        return ret
+
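The two helpers combine into the usual BM25L sum, score(q, d) = Σ_t idf(t) · tf(t, d). A hand-checked toy computation for a single query token, using the default parameters above and made-up corpus statistics:

```python
import math

# Toy numbers: one query token, one document (defaults k1=1.5, b=0.75, delta=0.5).
k, b, delta = 1.5, 0.75, 0.5
n_corpus, n = 3, 2           # 3 indexed documents; 2 contain the token
doc_len, avg_doc_len = 6, 5  # this document's length vs. the corpus average
freq_term = 2                # the token occurs twice in this document

idf = math.log((n_corpus + 1.0) / (n + 0.5))           # ln(4 / 2.5) ≈ 0.470
ctd = freq_term / (1 - b + b * doc_len / avg_doc_len)  # 2 / 1.15 ≈ 1.739
score = idf * (1.0 + k) * (ctd + delta) / (k + ctd + delta)
print(round(score, 3))  # ≈ 0.704
```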
+    def _score_bm25okapi(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
+        """
+        Calculate BM25Okapi scores for the given query and filtered documents.
+
+        :param query:
+            The query string.
+        :param documents:
+            The list of documents to score, should be produced by
+            the filter_documents method; may be an empty list.
+        :returns:
+            A list of tuples, each containing a Document and its BM25Okapi score.
+        """
+        k = self.bm25_parameters.get("k1", 1.5)
+        b = self.bm25_parameters.get("b", 0.75)
+        epsilon = self.bm25_parameters.get("epsilon", 0.25)
+
+        def _compute_idf(tokens: List[str]) -> Dict[str, float]:
+            """Per-token IDF computation for all tokens."""
+            sum_idf = 0.0
+            neg_idf_tokens = []
+
+            # Although this is a global statistic, we compute it here
+            # to make the computation more self-contained. And the
+            # complexity is O(vocab_size), which is acceptable.
+            idf = {}
+            for tok, n in self._freq_vocab_for_idf.items():
+                idf[tok] = math.log((len(self._bm25_attr) - n + 0.5) / (n + 0.5))
+                sum_idf += idf[tok]
+                if idf[tok] < 0:
+                    neg_idf_tokens.append(tok)
+
+            eps = epsilon * sum_idf / len(self._freq_vocab_for_idf)
+            for tok in neg_idf_tokens:
+                idf[tok] = eps
+            return {tok: idf.get(tok, 0.0) for tok in tokens}
+
+        def _compute_tf(token: str, freq: Dict[str, int], doc_len: int) -> float:
+            """Per-token BM25Okapi computation."""
+            freq_term = freq.get(token, 0.0)
+            freq_norm = freq_term + k * (1 - b + b * doc_len / self._avg_doc_len)
+            return freq_term * (1.0 + k) / freq_norm
+
+        idf = _compute_idf(self._tokenize_bm25(query))
+        bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}
+
+        ret = []
+        for doc in documents:
+            doc_stats = bm25_attr[doc.id]
+            freq = doc_stats.freq_token
+            doc_len = doc_stats.doc_len
+
+            score = 0.0
+            for tok in idf.keys():
+                score += idf[tok] * _compute_tf(tok, freq, doc_len)
+            ret.append((doc, score))
+
+        return ret
+
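The epsilon step mirrors how `rank_bm25` treats negative IDF: any token whose IDF comes out negative is floored at `epsilon` times the vocabulary-average IDF. On a highly repetitive corpus that average is itself negative, so the floor is negative too, which is why BM25Okapi can legitimately return negative scores (relevant to the `negatives_are_valid` check later in this diff). A toy illustration:

```python
import math

# Toy corpus: every token appears in every one of the 3 documents, so each
# IDF is negative, the vocabulary-average IDF is negative, and the epsilon
# floor (epsilon defaults to 0.25 above) inherits the negative sign.
n_docs = 3
doc_freq = {"haystack": 3, "installation": 3}
idf = {t: math.log((n_docs - n + 0.5) / (n + 0.5)) for t, n in doc_freq.items()}
sum_idf = sum(idf.values())            # ≈ -3.892
eps = 0.25 * sum_idf / len(doc_freq)   # ≈ -0.487, a negative "floor"
print(round(eps, 3))
```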
+    def _score_bm25plus(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
+        """
+        Calculate BM25+ scores for the given query and filtered documents.
+
+        This implementation follows the description on the BM25 Wikipedia page,
+        which adds 1 (smoothing factor) to the document frequency when computing IDF.
+
+        :param query:
+            The query string.
+        :param documents:
+            The list of documents to score, should be produced by
+            the filter_documents method; may be an empty list.
+        :returns:
+            A list of tuples, each containing a Document and its BM25+ score.
+        """
+        k = self.bm25_parameters.get("k1", 1.5)
+        b = self.bm25_parameters.get("b", 0.75)
+        delta = self.bm25_parameters.get("delta", 1.0)
+
+        def _compute_idf(tokens: List[str]) -> Dict[str, float]:
+            """Per-token IDF computation."""
+            idf = {}
+            n_corpus = len(self._bm25_attr)
+            for tok in tokens:
+                n = self._freq_vocab_for_idf.get(tok, 0)
+                idf[tok] = math.log(1 + (n_corpus - n + 0.5) / (n + 0.5)) * int(n != 0)
+            return idf
+
+        def _compute_tf(token: str, freq: Dict[str, int], doc_len: float) -> float:
+            """Per-token normalized term frequency."""
+            freq_term = freq.get(token, 0.0)
+            freq_damp = k * (1 - b + b * doc_len / self._avg_doc_len)
+            return freq_term * (1.0 + k) / (freq_term + freq_damp) + delta
+
+        idf = _compute_idf(self._tokenize_bm25(query))
+        bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}
+
+        ret = []
+        for doc in documents:
+            doc_stats = bm25_attr[doc.id]
+            freq = doc_stats.freq_token
+            doc_len = doc_stats.doc_len
+
+            score = 0.0
+            for tok in idf.keys():  # pylint: disable=consider-using-dict-items
+                score += idf[tok] * _compute_tf(tok, freq, doc_len)
+            ret.append((doc, score))
+
+        return ret
+
     def to_dict(self) -> Dict[str, Any]:
         """
         Serializes the component to a dictionary.
@@ -68,8 +273,8 @@ class InMemoryDocumentStore:
         """
         return default_to_dict(
             self,
-            bm25_tokenization_regex=self._bm25_tokenization_regex,
-            bm25_algorithm=self.bm25_algorithm.__name__,
+            bm25_tokenization_regex=self.bm25_tokenization_regex,
+            bm25_algorithm=self.bm25_algorithm,
             bm25_parameters=self.bm25_parameters,
             embedding_similarity_function=self.embedding_similarity_function,
         )
@@ -132,7 +337,36 @@ class InMemoryDocumentStore:
                 logger.warning("ID '{document_id}' already exists", document_id=document.id)
                 written_documents -= 1
                 continue
+
+            # Since the statistics are updated in an incremental manner,
+            # we need to explicitly remove the existing document to revert
+            # the statistics before updating them with the new document.
+            if document.id in self.storage.keys():
+                self.delete_documents([document.id])
+
+            # This processing logic is extracted from the original bm25_retrieval method.
+            # Since we are creating the index incrementally before the first retrieval,
+            # we need to determine what content to use for indexing here, not at query time.
+            if document.content is not None:
+                if document.dataframe is not None:
+                    logger.warning(
+                        "Document '{document_id}' has both text and dataframe content. "
+                        "Using text content for retrieval and skipping dataframe content.",
+                        document_id=document.id,
+                    )
+                tokens = self._tokenize_bm25(document.content)
+            elif document.dataframe is not None:
+                str_content = document.dataframe.astype(str)
+                csv_content = str_content.to_csv(index=False)
+                tokens = self._tokenize_bm25(csv_content)
+            else:
+                tokens = []
+
             self.storage[document.id] = document
+
+            self._bm25_attr[document.id] = BM25DocumentStats(Counter(tokens), len(tokens))
+            self._freq_vocab_for_idf.update(set(tokens))
+            self._avg_doc_len = (len(tokens) + self._avg_doc_len * len(self._bm25_attr)) / (len(self._bm25_attr) + 1)
         return written_documents
 
     def delete_documents(self, document_ids: List[str]) -> None:
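The last added line keeps `_avg_doc_len` current with a constant-time incremental-mean update instead of rescanning every document. A toy check of the arithmetic behind such an update (made-up numbers, with `n_docs` counting the documents averaged so far):

```python
# Incremental mean: fold one new document length into a running average.
avg_doc_len, n_docs = 5.0, 3  # average length over 3 indexed documents
new_doc_len = 9               # the incoming document has 9 tokens

avg_doc_len = (new_doc_len + avg_doc_len * n_docs) / (n_docs + 1)
n_docs += 1
print(avg_doc_len)  # (9 + 15) / 4 = 6.0
```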
@@ -146,6 +380,17 @@ class InMemoryDocumentStore:
                 continue
             del self.storage[doc_id]
 
+            # Update statistics accordingly
+            doc_stats = self._bm25_attr.pop(doc_id)
+            freq = doc_stats.freq_token
+            doc_len = doc_stats.doc_len
+
+            self._freq_vocab_for_idf.subtract(Counter(freq.keys()))
+            try:
+                self._avg_doc_len = (self._avg_doc_len * (len(self._bm25_attr) + 1) - doc_len) / len(self._bm25_attr)
+            except ZeroDivisionError:
+                self._avg_doc_len = 0
+
     def bm25_retrieval(
         self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = False
     ) -> List[Document]:
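Deletion runs the same update in reverse, so removing a document restores the average it would have had without it; continuing the toy numbers from the previous sketch:

```python
# Reverse update: remove the 9-token document from the running average.
avg_doc_len, n_docs = 6.0, 4
doc_len = 9

n_docs -= 1  # the document's stats have just been popped
avg_doc_len = (avg_doc_len * (n_docs + 1) - doc_len) / n_docs
print(avg_doc_len)  # (24 - 9) / 3 = 5.0
```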
@@ -174,65 +419,33 @@
             filters = {"operator": "AND", "conditions": [content_type_filter, filters]}
         else:
             filters = content_type_filter
 
         all_documents = self.filter_documents(filters=filters)
 
-        # Lowercase all documents
-        lower_case_documents = []
-        for doc in all_documents:
-            if doc.content is None and doc.dataframe is None:
-                logger.info(
-                    "Document '{document_id}' has no text or dataframe content. Skipping it.", document_id=doc.id
-                )
-            else:
-                if doc.content is not None:
-                    lower_case_documents.append(doc.content.lower())
-                    if doc.dataframe is not None:
-                        logger.warning(
-                            "Document '{document_id}' has both text and dataframe content. "
-                            "Using text content and skipping dataframe content.",
-                            document_id=doc.id,
-                        )
-                    continue
-                if doc.dataframe is not None:
-                    str_content = doc.dataframe.astype(str)
-                    csv_content = str_content.to_csv(index=False)
-                    lower_case_documents.append(csv_content.lower())
-
-        # Tokenize the entire content of the DocumentStore
-        tokenized_corpus = [
-            self.tokenizer(doc) for doc in tqdm(lower_case_documents, unit=" docs", desc="Ranking by BM25...")
-        ]
-        if len(tokenized_corpus) == 0:
+        if len(all_documents) == 0:
             logger.info("No documents found for BM25 retrieval. Returning empty list.")
             return []
 
-        # initialize BM25
-        bm25_scorer = self.bm25_algorithm(tokenized_corpus, **self.bm25_parameters)
-        # tokenize query
-        tokenized_query = self.tokenizer(query.lower())
-        # get scores for the query against the corpus
-        docs_scores = bm25_scorer.get_scores(tokenized_query)
-        if scale_score:
-            docs_scores = [expit(float(score / BM25_SCALING_FACTOR)) for score in docs_scores]
-        # get the last top_k indexes and reverse them
-        top_docs_positions = np.argsort(docs_scores)[-top_k:][::-1]
+        results = sorted(self.bm25_algorithm_inst(query, all_documents), key=lambda x: x[1], reverse=True)[:top_k]
 
         # BM25Okapi can return meaningful negative values, so they should not be filtered out when scale_score is False.
         # It's the only algorithm supported by rank_bm25 at the time of writing (2024) that can return negative scores.
         # see https://github.com/deepset-ai/haystack/pull/6889 for more context.
-        negatives_are_valid = self.bm25_algorithm is rank_bm25.BM25Okapi and not scale_score
+        negatives_are_valid = self.bm25_algorithm == "BM25Okapi" and not scale_score
 
         # Create documents with the BM25 score to return them
         return_documents = []
-        for i in top_docs_positions:
-            doc = all_documents[i]
-            score = docs_scores[i]
+        for doc, score in results:
+            if scale_score:
+                score = expit(score / BM25_SCALING_FACTOR)
+
             if not negatives_are_valid and score <= 0.0:
                 continue
 
             doc_fields = doc.to_dict()
             doc_fields["score"] = score
             return_document = Document.from_dict(doc_fields)
             return_documents.append(return_document)
 
         return return_documents
 
     def embedding_retrieval(
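With these changes, a query scores only the filtered documents against statistics maintained at write time, rather than re-tokenizing the whole store. A minimal usage sketch (assuming the Haystack 2.x import path; document contents are illustrative):

```python
from haystack import Document
from haystack.document_stores.in_memory import InMemoryDocumentStore

# The BM25 index is built incrementally as documents are written.
store = InMemoryDocumentStore(bm25_algorithm="BM25L")
store.write_documents(
    [
        Document(content="BM25 is a ranking function"),
        Document(content="Haystack is an LLM framework"),
    ]
)

results = store.bm25_retrieval(query="ranking function", top_k=1, scale_score=False)
print(results[0].content, results[0].score)
```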
@@ -47,7 +47,6 @@ classifiers = [
 ]
 dependencies = [
     "pandas",
-    "haystack-bm25",
     "tqdm",
     "tenacity",
     "lazy-imports",
@@ -0,0 +1,7 @@
+---
+enhancements:
+  - |
+    Re-implement `InMemoryDocumentStore` BM25 search with incremental indexing by avoiding re-creating
+    the entire inverse index for every new query. This change also removes the dependency on
+    `haystack_bm25`. Please refer to [PR #7549](https://github.com/deepset-ai/haystack/pull/7549)
+    for the full context.
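Since the algorithm is now stored as a plain string, the selection survives `to_dict`/`from_dict` round trips without referencing `rank_bm25` classes. A brief sketch (assuming `default_to_dict`'s usual `init_parameters` layout; parameter values are illustrative):

```python
from haystack.document_stores.in_memory import InMemoryDocumentStore

# Algorithm choice and parameters are plain constructor arguments.
store = InMemoryDocumentStore(
    bm25_algorithm="BM25Plus",  # one of "BM25Okapi", "BM25L", "BM25Plus"
    bm25_parameters={"k1": 1.2, "b": 0.6, "delta": 1.0},
)

data = store.to_dict()
print(data["init_parameters"]["bm25_algorithm"])  # "BM25Plus"
restored = InMemoryDocumentStore.from_dict(data)
assert restored.bm25_algorithm == "BM25Plus"
```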
@@ -3,7 +3,6 @@ from unittest.mock import patch
 
 import pandas as pd
 import pytest
-from haystack_bm25 import rank_bm25
 
 from haystack import Document
 from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
@@ -64,9 +63,13 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):  # pylint: disable=R0904
         store = InMemoryDocumentStore.from_dict(data)
         mock_regex.compile.assert_called_with("custom_regex")
         assert store.tokenizer
-        assert store.bm25_algorithm.__name__ == "BM25Plus"
+        assert store.bm25_algorithm == "BM25Plus"
         assert store.bm25_parameters == {"key": "value"}
 
+    def test_invalid_bm25_algorithm(self):
+        with pytest.raises(ValueError, match="BM25 algorithm 'invalid' is not supported"):
+            InMemoryDocumentStore(bm25_algorithm="invalid")
+
     def test_write_documents(self, document_store):
         docs = [Document(id="1")]
         assert document_store.write_documents(docs) == 1
@@ -113,7 +116,18 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):  # pylint: disable=R0904
         results = document_store.bm25_retrieval(query="languages", top_k=3)
         assert len(results) == 3
 
     # Test two queries and make sure the results are different
+    def test_bm25_plus_retrieval(self):
+        doc_store = InMemoryDocumentStore(bm25_algorithm="BM25Plus")
+        docs = [
+            Document(content="Hello world"),
+            Document(content="Haystack supports multiple languages"),
+            Document(content="Python is a popular programming language"),
+        ]
+        doc_store.write_documents(docs)
+
+        results = doc_store.bm25_retrieval(query="language", top_k=1)
+        assert len(results) == 1
+        assert results[0].content == "Python is a popular programming language"
+
     def test_bm25_retrieval_with_two_queries(self, document_store: InMemoryDocumentStore):
         # Tests if the bm25_retrieval method returns different documents for different queries.
@@ -166,7 +180,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):  # pylint: disable=R0904
         results = document_store.bm25_retrieval(query="Python", top_k=1, scale_score=False)
         assert results[0].score != results1[0].score
 
-    def test_bm25_retrieval_with_non_scaled_BM25Okapi(self, document_store: InMemoryDocumentStore):
+    def test_bm25_retrieval_with_non_scaled_BM25Okapi(self):
         # Highly repetitive documents make BM25Okapi return negative scores, which should not be filtered if the
         # scores are not scaled
         docs = [
@@ -188,9 +202,9 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):  # pylint: disable=R0904
                 to try the new features as soon as they are merged."""
             ),
         ]
+        document_store = InMemoryDocumentStore(bm25_algorithm="BM25Okapi")
         document_store.write_documents(docs)
 
-        document_store.bm25_algorithm = rank_bm25.BM25Okapi
         results1 = document_store.bm25_retrieval(query="Haystack installation", top_k=10, scale_score=False)
         assert len(results1) == 3
         assert all(res.score < 0.0 for res in results1)
@@ -215,11 +229,11 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):  # pylint: disable=R0904
         table_content = pd.DataFrame({"language": ["Python", "Java"], "use": ["Data Science", "Web Development"]})
         document = Document(content="Gardening", dataframe=table_content)
         docs = [
-            document,
             Document(content="Python"),
             Document(content="Bird Watching"),
             Document(content="Gardening"),
             Document(content="Java"),
+            document,
         ]
         document_store.write_documents(docs)
         results = document_store.bm25_retrieval(query="Gardening", top_k=2)