feat: BM25 retrieval for MemoryDocumentStore (#5151)

This commit is contained in:
Vladimir Blagojevic 2023-06-27 17:42:23 +02:00 committed by GitHub
parent c068e34954
commit bc86f57715
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 509 additions and 3 deletions

View File

@ -0,0 +1,84 @@
from dataclasses import dataclass
from typing import Dict, List, Any, Optional
from haystack.preview import component, Document, ComponentInput, ComponentOutput
from haystack.preview.document_stores import MemoryDocumentStore
@component
class MemoryRetriever:
    """
    Component that ranks the documents of a MemoryDocumentStore against a query with BM25.
    """

    @dataclass
    class Input(ComponentInput):
        """
        Values consumed by a single MemoryRetriever run.

        :param query: The query string for the retriever.
        :param filters: A dictionary with filters to narrow down the search space.
        :param top_k: The maximum number of documents to return.
        :param scale_score: Whether to scale the BM25 scores or not.
        :param stores: A dictionary mapping document store names to instances.
        """

        query: str
        filters: Dict[str, Any]
        top_k: int
        scale_score: bool
        stores: Dict[str, Any]

    @dataclass
    class Output(ComponentOutput):
        """
        Values produced by a single MemoryRetriever run.

        :param documents: The retrieved documents.
        """

        documents: List[Document]

    def __init__(
        self,
        document_store_name: str,
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
        scale_score: bool = True,
    ):
        """
        Create a MemoryRetriever component.

        :param document_store_name: The name of the MemoryDocumentStore to retrieve documents from.
        :param filters: A dictionary with filters to narrow down the search space (default is None,
            which is stored as an empty dictionary).
        :param top_k: The maximum number of documents to retrieve (default is 10).
        :param scale_score: Whether to scale the BM25 score or not (default is True).
        :raises ValueError: If the specified top_k is not > 0.
        """
        self.document_store_name = document_store_name
        if top_k <= 0:
            raise ValueError(f"top_k must be > 0, but got {top_k}")

        default_filters = filters or {}
        self.defaults = {"top_k": top_k, "scale_score": scale_score, "filters": default_filters}

    def run(self, data: Input) -> Output:
        """
        Look up the configured document store in the input and run BM25 retrieval on it.

        :param data: The input data for the retriever.
        :return: An Output wrapping the documents returned by the store's `bm25_retrieval`.
        :raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance.
        """
        store_name = self.document_store_name
        available_stores = data.stores
        if store_name not in available_stores:
            raise ValueError(
                f"MemoryRetriever's document store '{store_name}' not found "
                f"in input stores {list(available_stores.keys())}"
            )

        store = available_stores[store_name]
        if not isinstance(store, MemoryDocumentStore):
            raise ValueError("MemoryRetriever can only be used with a MemoryDocumentStore instance.")

        ranked_documents = store.bm25_retrieval(
            query=data.query,
            filters=data.filters,
            top_k=data.top_k,
            scale_score=data.scale_score,
        )
        return MemoryRetriever.Output(documents=ranked_documents)

View File

@ -1,26 +1,49 @@
import re
from typing import Literal, Any, Dict, List, Optional, Iterable
import logging
import numpy as np
import rank_bm25
from tqdm.auto import tqdm
from haystack.preview.dataclasses import Document
from haystack.preview.document_stores.memory._filters import match
from haystack.preview.document_stores.errors import DuplicateDocumentError, MissingDocumentError
from haystack.utils.scipy_utils import expit
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Accepted duplicate-handling policy names (presumably consumed when writing documents; the usage is
# outside this view -- confirm before relying on it).
DuplicatePolicy = Literal["skip", "overwrite", "fail"]
# Document scores are essentially unbounded and will be scaled to values between 0 and 1 if scale_score is set to
# True (default). Scaling uses the expit function (inverse of the logit function) after dividing the score by
# SCALING_FACTOR. A larger SCALING_FACTOR decreases scaled scores. For example, an input of 10 is scaled to 0.99
# with SCALING_FACTOR=2 but to 0.78 with SCALING_FACTOR=8 (default). The default was chosen empirically. Increase
# the default if most unscaled scores are larger than expected (>30) and otherwise would incorrectly all be mapped
# to scores ~1.
SCALING_FACTOR = 8
class MemoryDocumentStore:
"""
Stores data in-memory. It's ephemeral and cannot be saved to disk.
"""
def __init__(
    self,
    bm25_tokenization_regex: str = r"(?u)\b\w\w+\b",
    bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25Okapi",
    bm25_parameters: Optional[Dict] = None,
):
    """
    Initializes the store.

    :param bm25_tokenization_regex: Regular expression whose matches (via `findall`) are the tokens
        used for BM25 retrieval.
    :param bm25_algorithm: Name of the class to look up in the `rank_bm25` module.
    :param bm25_parameters: Keyword arguments passed to the BM25 class when a scorer is created.
        None is stored as an empty dictionary.
    :raises ValueError: If `bm25_algorithm` does not name an attribute of `rank_bm25`.
    """
    self.storage: Dict[str, Document] = {}
    self.tokenizer = re.compile(bm25_tokenization_regex).findall

    # getattr needs an explicit default here: without it an unknown name raises AttributeError
    # and the ValueError below could never be reached.
    algorithm_class = getattr(rank_bm25, bm25_algorithm, None)
    if algorithm_class is None:
        raise ValueError(f"BM25 algorithm '{bm25_algorithm}' not found.")
    self.bm25_algorithm = algorithm_class
    self.bm25_parameters = bm25_parameters or {}
def count_documents(self) -> int:
"""
@ -142,3 +165,75 @@ class MemoryDocumentStore:
if not doc_id in self.storage.keys():
raise MissingDocumentError(f"ID '{doc_id}' not found, cannot delete it.")
del self.storage[doc_id]
def bm25_retrieval(
    self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = True
) -> List[Document]:
    """
    Retrieves documents that are most relevant to the query using BM25 algorithm.

    :param query: The query string. Must be non-empty.
    :param filters: A dictionary with filters to narrow down the search space. If it has no
        "content_type" entry, the search is restricted to "text" and "table" documents.
    :param top_k: The number of top documents to retrieve. Default is 10; None also falls back to 10.
    :param scale_score: Whether to scale the scores of the retrieved documents. Default is True.
    :return: A list of the top 'k' documents most relevant to the query, best match first. Each one is a
        new Document built from the stored document's fields plus its score.
    :raises ValueError: If the query is empty, if top_k is not > 0, or if the filters ask for a
        content type other than "text" or "table".
    """
    if not query:
        raise ValueError("Query should be a non-empty string")

    # FIXME: remove this guard after resolving https://github.com/deepset-ai/canals/issues/33
    top_k = top_k if top_k is not None else 10
    # Without this check top_k=0 would slice with [-0:] and return every document, and a negative
    # top_k would silently drop the lowest-ranked documents instead of limiting the result.
    if top_k <= 0:
        raise ValueError(f"top_k must be > 0, but got {top_k}")

    # Get all documents that match the user's filters AND are either 'table' or 'text'.
    # Raises an exception if the user was trying to include other content types.
    if filters and "content_type" in filters:
        content_types = filters["content_type"]
        if isinstance(content_types, str):
            content_types = [content_types]
        if any(type_ not in ["text", "table"] for type_ in content_types):
            raise ValueError(
                "MemoryDocumentStore can do BM25 retrieval on no other document type than text or table."
            )
    else:
        # Build a new dictionary so the caller's filters are never mutated.
        filters = filters or {}
        filters = {**filters, "content_type": ["text", "table"]}
    all_documents = self.filter_documents(filters=filters)

    # Lowercase all documents; tables are rendered as CSV text first.
    lower_case_documents = []
    for doc in all_documents:
        if doc.content_type == "text":
            lower_case_documents.append(doc.content.lower())
        elif doc.content_type == "table":
            str_content = doc.content.astype(str)
            csv_content = str_content.to_csv(index=False)
            lower_case_documents.append(csv_content.lower())

    # Nothing to rank: return before tokenizing or building a scorer.
    if not lower_case_documents:
        logger.info("No documents found for BM25 retrieval. Returning empty list.")
        return []

    # Tokenize the entire content of the document store
    tokenized_corpus = [
        self.tokenizer(doc) for doc in tqdm(lower_case_documents, unit=" docs", desc="Ranking by BM25...")
    ]

    # initialize BM25
    bm25_scorer = self.bm25_algorithm(tokenized_corpus, **self.bm25_parameters)
    # tokenize query
    tokenized_query = self.tokenizer(query.lower())
    # get scores for the query against the corpus
    docs_scores = bm25_scorer.get_scores(tokenized_query)
    if scale_score:
        docs_scores = [float(expit(np.asarray(score / SCALING_FACTOR))) for score in docs_scores]
    # get the last top_k indexes and reverse them
    top_docs_positions = np.argsort(docs_scores)[-top_k:][::-1]

    # Create documents with the BM25 score to return them
    return_documents = []
    for i in top_docs_positions:
        doc_fields = all_documents[i].to_dict()
        # Convert to a plain float so the score type is the same whether or not scaling was applied.
        doc_fields["score"] = float(docs_scores[i])
        return_documents.append(Document(**doc_fields))
    return return_documents

View File

@ -0,0 +1,134 @@
from typing import Dict, Any, List
import pytest
from haystack.preview import Pipeline
from haystack.preview.components.retrievers.memory import MemoryRetriever
from haystack.preview.dataclasses import Document
from haystack.preview.document_stores import MemoryDocumentStore
from test.preview.components.base import BaseTestComponent
@pytest.fixture()
def mock_docs():
    """Return five documents whose contents differ only in the programming language they name."""
    contents = (
        "Javascript is a popular programming language",
        "Java is a popular programming language",
        "Python is a popular programming language",
        "Ruby is a popular programming language",
        "PHP is a popular programming language",
    )
    return [Document.from_dict({"content": text}) for text in contents]
class Test_MemoryRetriever(BaseTestComponent):
    """Unit and integration tests for the MemoryRetriever component."""

    @staticmethod
    def _pipeline_with_store(documents) -> Pipeline:
        """Build a Pipeline with a MemoryRetriever named "retriever" and a filled store named "memory"."""
        store = MemoryDocumentStore()
        store.write_documents(documents)
        retriever = MemoryRetriever(document_store_name="memory")
        pipe = Pipeline()
        pipe.add_component("retriever", retriever)
        pipe.add_store("memory", store)
        return pipe

    @pytest.mark.unit
    def test_save_load(self, tmp_path):
        """A retriever built with default parameters passes the save/load pipeline check."""
        component_under_test = MemoryRetriever(document_store_name="memory")
        self.assert_can_be_saved_and_loaded_in_pipeline(component_under_test, tmp_path)

    @pytest.mark.unit
    def test_save_load_with_parameters(self, tmp_path):
        """A retriever built with custom parameters passes the save/load pipeline check."""
        component_under_test = MemoryRetriever(document_store_name="memory", top_k=5, scale_score=False)
        self.assert_can_be_saved_and_loaded_in_pipeline(component_under_test, tmp_path)

    @pytest.mark.unit
    def test_init_default(self):
        """Default construction stores the store name and the default run parameters."""
        component_under_test = MemoryRetriever(document_store_name="memory")
        assert component_under_test.document_store_name == "memory"
        assert component_under_test.defaults == {"filters": {}, "top_k": 10, "scale_score": True}

    @pytest.mark.unit
    def test_init_with_parameters(self):
        """Explicit constructor arguments end up in the stored defaults."""
        component_under_test = MemoryRetriever(document_store_name="memory-test", top_k=5, scale_score=False)
        assert component_under_test.document_store_name == "memory-test"
        assert component_under_test.defaults == {"filters": {}, "top_k": 5, "scale_score": False}

    @pytest.mark.unit
    def test_init_with_invalid_top_k_parameter(self):
        """A non-positive top_k is rejected at construction time."""
        with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
            MemoryRetriever(document_store_name="memory-test", top_k=-2, scale_score=False)

    @pytest.mark.unit
    def test_valid_run(self, mock_docs):
        """Running directly against a filled store returns top_k documents, best match first."""
        top_k = 5
        store = MemoryDocumentStore()
        store.write_documents(mock_docs)
        retriever = MemoryRetriever(document_store_name="memory", top_k=top_k)

        run_input = MemoryRetriever.Input(query="PHP", stores={"memory": store})
        result: MemoryRetriever.Output = retriever.run(data=run_input)

        assert result.documents
        assert len(result.documents) == top_k
        assert result.documents[0].content == "PHP is a popular programming language"

    @pytest.mark.unit
    def test_invalid_run_wrong_store_name(self):
        """Running fails when the configured store name is missing from the input stores."""
        store = MemoryDocumentStore()
        retriever = MemoryRetriever(document_store_name="memory")
        with pytest.raises(ValueError, match=r"MemoryRetriever's document store 'memory' not found"):
            bad_input = MemoryRetriever.Input(
                query="test", top_k=10, scale_score=True, stores={"invalid_store": store}
            )
            retriever.run(bad_input)

    @pytest.mark.unit
    def test_invalid_run_wrong_store_type(self):
        """Running fails when the named store is not a MemoryDocumentStore."""
        # Kept from the original test even though the store instance is not passed to the retriever.
        store = MemoryDocumentStore()
        retriever = MemoryRetriever(document_store_name="memory")
        with pytest.raises(ValueError, match=r"MemoryRetriever can only be used with a MemoryDocumentStore instance."):
            bad_input = MemoryRetriever.Input(
                query="test", top_k=10, scale_score=True, stores={"memory": "not a MemoryDocumentStore"}
            )
            retriever.run(bad_input)

    @pytest.mark.integration
    @pytest.mark.parametrize(
        "query, query_result",
        [
            ("Javascript", "Javascript is a popular programming language"),
            ("Java", "Java is a popular programming language"),
        ],
    )
    def test_run_with_pipeline(self, mock_docs, query: str, query_result: str):
        """Inside a Pipeline, the best match for each query is the expected document."""
        pipe = self._pipeline_with_store(mock_docs)

        result: Dict[str, Any] = pipe.run(data={"retriever": MemoryRetriever.Input(query=query)})

        assert result
        assert "retriever" in result
        retrieved = result["retriever"].documents
        assert retrieved
        assert retrieved[0].content == query_result

    @pytest.mark.integration
    @pytest.mark.parametrize(
        "query, query_result, top_k",
        [
            ("Javascript", "Javascript is a popular programming language", 1),
            ("Java", "Java is a popular programming language", 2),
            ("Ruby", "Ruby is a popular programming language", 3),
        ],
    )
    def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int):
        """Inside a Pipeline, a run-time top_k limits the number of returned documents."""
        pipe = self._pipeline_with_store(mock_docs)

        run_input = MemoryRetriever.Input(query=query, top_k=top_k)
        result: Dict[str, Any] = pipe.run(data={"retriever": run_input})

        assert result
        assert "retriever" in result
        retrieved = result["retriever"].documents
        assert retrieved
        assert len(retrieved) == top_k
        assert retrieved[0].content == query_result

View File

@ -1,4 +1,9 @@
import logging
import pandas as pd
import pytest
from haystack.preview import Document
from haystack.preview.document_stores import MemoryDocumentStore
from test.preview.document_stores._base import DocumentStoreBaseTests
@ -12,3 +17,191 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
@pytest.fixture
def docstore(self) -> MemoryDocumentStore:
    """Provide a newly constructed MemoryDocumentStore with default settings."""
    store = MemoryDocumentStore()
    return store
@pytest.mark.unit
def test_bm25_retrieval(self, docstore):
    """bm25_retrieval returns the document that matches the query."""
    # Use the injected fixture directly; it already is a default MemoryDocumentStore, so
    # re-creating one here only shadowed the argument.
    docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
    docstore.write_documents(docs)
    results = docstore.bm25_retrieval(query="What languages?", top_k=1, filters={})
    assert len(results) == 1
    assert results[0].content == "Haystack supports multiple languages"
@pytest.mark.unit
def test_bm25_retrieval_with_empty_document_store(self, docstore, caplog):
    """On an empty store, bm25_retrieval returns an empty list and logs an INFO message."""
    caplog.set_level(logging.INFO)
    hits = docstore.bm25_retrieval(query="How to test this?", top_k=2)
    assert len(hits) == 0
    assert "No documents found for BM25 retrieval. Returning empty list." in caplog.text
@pytest.mark.unit
def test_bm25_retrieval_empty_query(self, docstore):
    """An empty query string is rejected with a ValueError."""
    docstore.write_documents(
        [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
    )
    with pytest.raises(ValueError, match=r"Query should be a non-empty string"):
        docstore.bm25_retrieval(query="", top_k=1)
@pytest.mark.unit
def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_string(self, docstore):
    """A content_type filter given as the string "text" returns only the text document."""
    table_doc = Document.from_dict(
        {"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}
    )
    text_doc = Document(content="Haystack supports multiple languages")
    docstore.write_documents([table_doc, text_doc])

    hits = docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": "text"})

    assert len(hits) == 1
    assert hits[0].content == "Haystack supports multiple languages"
@pytest.mark.unit
def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_list(self, docstore):
    """A content_type filter given as the list ["text"] returns only the text document."""
    table_doc = Document.from_dict(
        {"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}
    )
    text_doc = Document(content="Haystack supports multiple languages")
    docstore.write_documents([table_doc, text_doc])

    hits = docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["text"]})

    assert len(hits) == 1
    assert hits[0].content == "Haystack supports multiple languages"
@pytest.mark.unit
def test_bm25_retrieval_filter_two_allowed_doc_type_as_list(self, docstore):
    """Filtering on both allowed content types returns the table and the text document."""
    table_doc = Document.from_dict(
        {"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}
    )
    text_doc = Document(content="Haystack supports multiple languages")
    docstore.write_documents([table_doc, text_doc])

    hits = docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["text", "table"]})

    assert len(hits) == 2
@pytest.mark.unit
def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_string(self, docstore):
    """A content_type filter given as the string "audio" is rejected with a ValueError."""
    table_doc = Document(content=pd.DataFrame({"language": ["Python", "Java"]}), content_type="table")
    text_doc = Document(content="Haystack supports multiple languages")
    docstore.write_documents([table_doc, text_doc])

    with pytest.raises(
        ValueError, match="MemoryDocumentStore can do BM25 retrieval on no other document type than text or table."
    ):
        docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": "audio"})
@pytest.mark.unit
def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_list(self, docstore):
    """A content_type filter given as the list ["audio"] is rejected with a ValueError."""
    table_doc = Document(content=pd.DataFrame({"language": ["Python", "Java"]}), content_type="table")
    text_doc = Document(content="Haystack supports multiple languages")
    docstore.write_documents([table_doc, text_doc])

    with pytest.raises(
        ValueError, match="MemoryDocumentStore can do BM25 retrieval on no other document type than text or table."
    ):
        docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["audio"]})
@pytest.mark.unit
def test_bm25_retrieval_filter_two_not_all_allowed_doc_type_as_list(self, docstore):
    """A content_type list mixing an allowed and a disallowed type is rejected with a ValueError."""
    table_doc = Document.from_dict(
        {"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}
    )
    text_doc = Document(content="Haystack supports multiple languages")
    docstore.write_documents([table_doc, text_doc])

    with pytest.raises(
        ValueError, match="MemoryDocumentStore can do BM25 retrieval on no other document type than text or table."
    ):
        docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["text", "audio"]})
@pytest.mark.unit
def test_bm25_retrieval_with_different_top_k(self, docstore):
    """The number of returned documents follows the top_k argument."""
    docstore.write_documents(
        [
            Document(content="Hello world"),
            Document(content="Haystack supports multiple languages"),
            Document(content="Python is a popular programming language"),
        ]
    )

    two_hits = docstore.bm25_retrieval(query="languages", top_k=2)
    assert len(two_hits) == 2

    three_hits = docstore.bm25_retrieval(query="languages", top_k=3)
    assert len(three_hits) == 3
# Two different queries against the same store must surface different best matches
@pytest.mark.unit
def test_bm25_retrieval_with_two_queries(self, docstore):
    """Different queries return different top documents from the same store."""
    contents = (
        "Javascript is a popular programming language",
        "Java is a popular programming language",
        "Python is a popular programming language",
        "Ruby is a popular programming language",
        "PHP is a popular programming language",
    )
    docstore.write_documents([Document(content=text) for text in contents])

    java_hits = docstore.bm25_retrieval(query="Java", top_k=1)
    assert java_hits[0].content == "Java is a popular programming language"

    python_hits = docstore.bm25_retrieval(query="Python", top_k=1)
    assert python_hits[0].content == "Python is a popular programming language"
# Run the same query while the store grows and check the result after each write
@pytest.mark.unit
def test_bm25_retrieval_with_updated_docs(self, docstore):
    """Documents written after an earlier retrieval are taken into account by later retrievals."""
    # Only a non-matching document is stored: one result is still returned.
    docstore.write_documents([Document(content="Hello world")])
    hits = docstore.bm25_retrieval(query="Python", top_k=1)
    assert len(hits) == 1

    # After adding a matching document it becomes the top result.
    docstore.write_documents([Document(content="Python is a popular programming language")])
    hits = docstore.bm25_retrieval(query="Python", top_k=1)
    assert len(hits) == 1
    assert hits[0].content == "Python is a popular programming language"

    # Adding another non-matching document does not change the top result.
    docstore.write_documents([Document(content="Java is a popular programming language")])
    hits = docstore.bm25_retrieval(query="Python", top_k=1)
    assert len(hits) == 1
    assert hits[0].content == "Python is a popular programming language"
@pytest.mark.unit
def test_bm25_retrieval_with_scale_score(self, docstore):
    """Scaled scores lie in [0, 1] and differ from the unscaled score for the same query."""
    docstore.write_documents([Document(content="Python programming"), Document(content="Java programming")])

    scaled_hits = docstore.bm25_retrieval(query="Python", top_k=1, scale_score=True)
    # Confirm that score is scaled between 0 and 1
    assert 0 <= scaled_hits[0].score <= 1

    # Same query without scaling: the score must not be the scaled one
    unscaled_hits = docstore.bm25_retrieval(query="Python", top_k=1, scale_score=False)
    assert unscaled_hits[0].score != scaled_hits[0].score
@pytest.mark.unit
def test_bm25_retrieval_with_table_content(self, docstore):
    """A matching table document is returned with its DataFrame content intact."""
    table_content = pd.DataFrame({"language": ["Python", "Java"], "use": ["Data Science", "Web Development"]})
    docstore.write_documents(
        [
            Document(content=table_content, content_type="table"),
            Document(content="Gardening", content_type="text"),
            Document(content="Bird watching", content_type="text"),
        ]
    )

    hits = docstore.bm25_retrieval(query="Java", top_k=1)

    assert len(hits) == 1
    returned_table = hits[0].content
    assert isinstance(returned_table, pd.DataFrame)
    assert returned_table.equals(table_content)