mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-08 04:56:45 +00:00
feat: BM25 retrieval for MemoryDocumentStore (#5151)
This commit is contained in:
parent
c068e34954
commit
bc86f57715
0
haystack/preview/components/retrievers/__init__.py
Normal file
0
haystack/preview/components/retrievers/__init__.py
Normal file
84
haystack/preview/components/retrievers/memory.py
Normal file
84
haystack/preview/components/retrievers/memory.py
Normal file
@ -0,0 +1,84 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
from haystack.preview import component, Document, ComponentInput, ComponentOutput
|
||||
from haystack.preview.document_stores import MemoryDocumentStore
|
||||
|
||||
|
||||
@component
class MemoryRetriever:
    """
    Component that fetches documents from a MemoryDocumentStore by ranking them with the BM25 algorithm.
    """

    @dataclass
    class Input(ComponentInput):
        """
        Data consumed by one run of the MemoryRetriever.

        :param query: Query string the documents are ranked against.
        :param filters: Dictionary of filters restricting which documents are searched.
        :param top_k: Upper bound on the number of documents to return.
        :param scale_score: Whether the BM25 scores should be scaled.
        :param stores: Dictionary mapping document store names to document store instances.
        """

        query: str
        filters: Dict[str, Any]
        top_k: int
        scale_score: bool
        stores: Dict[str, Any]

    @dataclass
    class Output(ComponentOutput):
        """
        Data produced by one run of the MemoryRetriever.

        :param documents: The documents returned by the retrieval.
        """

        documents: List[Document]

    def __init__(
        self,
        document_store_name: str,
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
        scale_score: bool = True,
    ):
        """
        Build a MemoryRetriever bound to a named document store.

        :param document_store_name: Name under which the MemoryDocumentStore is looked up at run time.
        :param filters: Default filters used to narrow down the search space (None means no filters).
        :param top_k: Default maximum number of documents to retrieve.
        :param scale_score: Default for whether the BM25 score is scaled.

        :raises ValueError: If top_k is not greater than zero.
        """
        self.document_store_name = document_store_name
        if top_k <= 0:
            raise ValueError(f"top_k must be > 0, but got {top_k}")
        # A missing (or empty) filters argument is stored as an empty dictionary.
        default_filters = filters if filters else {}
        self.defaults = dict(top_k=top_k, scale_score=scale_score, filters=default_filters)

    def run(self, data: Input) -> Output:
        """
        Look up the configured document store in the input and run a BM25 retrieval on it.

        :param data: The input of this run; must carry the stores to pick from.
        :return: An Output wrapping the retrieved documents.

        :raises ValueError: If no store is registered under the configured name, or if that store
            is not a MemoryDocumentStore instance.
        """
        stores = data.stores
        if self.document_store_name not in stores:
            raise ValueError(
                f"MemoryRetriever's document store '{self.document_store_name}' not found "
                f"in input stores {list(data.stores.keys())}"
            )

        store = stores[self.document_store_name]
        if isinstance(store, MemoryDocumentStore):
            retrieved = store.bm25_retrieval(
                query=data.query, filters=data.filters, top_k=data.top_k, scale_score=data.scale_score
            )
            return MemoryRetriever.Output(documents=retrieved)

        raise ValueError("MemoryRetriever can only be used with a MemoryDocumentStore instance.")
|
||||
@ -1,26 +1,49 @@
|
||||
import re
|
||||
from typing import Literal, Any, Dict, List, Optional, Iterable
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import rank_bm25
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from haystack.preview.dataclasses import Document
|
||||
from haystack.preview.document_stores.memory._filters import match
|
||||
from haystack.preview.document_stores.errors import DuplicateDocumentError, MissingDocumentError
|
||||
|
||||
from haystack.utils.scipy_utils import expit
|
||||
|
||||
logger = logging.getLogger(__name__)
DuplicatePolicy = Literal["skip", "overwrite", "fail"]

# BM25 scores are essentially unbounded. When scale_score is True (the default) each score is divided by
# SCALING_FACTOR and passed through the expit function (the inverse of the logit function), which maps it to a
# value between 0 and 1. A larger SCALING_FACTOR decreases the scaled scores: for example, an input of 10 is
# scaled to 0.99 with SCALING_FACTOR=2 but to 0.78 with SCALING_FACTOR=8 (default). The default was chosen
# empirically. Increase it if most unscaled scores are larger than expected (>30), as they would otherwise all
# be incorrectly mapped to scores ~1.
SCALING_FACTOR = 8
|
||||
|
||||
|
||||
class MemoryDocumentStore:
|
||||
"""
|
||||
Stores data in-memory. It's ephemeral and cannot be saved to disk.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    bm25_tokenization_regex: str = r"(?u)\b\w\w+\b",
    bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25Okapi",
    bm25_parameters: Optional[Dict] = None,
):
    """
    Initializes the store.

    :param bm25_tokenization_regex: Regular expression whose matches (via `findall`) are the tokens used
        for BM25 retrieval.
    :param bm25_algorithm: Name of the class in the `rank_bm25` module that implements the scoring.
    :param bm25_parameters: Keyword arguments passed to the BM25 class when a scorer is created
        (None means no extra arguments).

    :raises ValueError: If `rank_bm25` has no attribute called `bm25_algorithm`.
    """
    self.storage: Dict[str, Document] = {}
    self.tokenizer = re.compile(bm25_tokenization_regex).findall
    # Without the explicit default, getattr raises AttributeError for an unknown name and the
    # ValueError below could never be reached.
    algorithm_class = getattr(rank_bm25, bm25_algorithm, None)
    if algorithm_class is None:
        raise ValueError(f"BM25 algorithm '{bm25_algorithm}' not found.")
    self.bm25_algorithm = algorithm_class
    self.bm25_parameters = bm25_parameters or {}
|
||||
|
||||
def count_documents(self) -> int:
|
||||
"""
|
||||
@ -142,3 +165,75 @@ class MemoryDocumentStore:
|
||||
if not doc_id in self.storage.keys():
|
||||
raise MissingDocumentError(f"ID '{doc_id}' not found, cannot delete it.")
|
||||
del self.storage[doc_id]
|
||||
|
||||
def bm25_retrieval(
    self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = True
) -> List[Document]:
    """
    Retrieves documents that are most relevant to the query using BM25 algorithm.

    :param query: The query string.
    :param filters: A dictionary with filters to narrow down the search space.
    :param top_k: The number of top documents to retrieve. Default is 10.
    :param scale_score: Whether to scale the scores of the retrieved documents. Default is True.
    :return: A list of the top 'k' documents most relevant to the query, best match first.

    :raises ValueError: If the query is empty, if top_k is not > 0, or if the filters ask for a
        content type other than 'text' or 'table'.
    """
    if not query:
        raise ValueError("Query should be a non-empty string")

    # FIXME: remove this guard after resolving https://github.com/deepset-ai/canals/issues/33
    top_k = top_k if top_k is not None else 10
    if top_k <= 0:
        # The `[-top_k:]` slice below would return *all* documents for 0 and would drop the
        # lowest-scoring ones (instead of limiting the result) for a negative value.
        raise ValueError(f"top_k must be > 0, but got {top_k}")

    # Get all documents that match the user's filters AND are either 'table' or 'text'.
    # Raises an exception if the user was trying to include other content types.
    if filters and "content_type" in filters:
        content_types = filters["content_type"]
        if isinstance(content_types, str):
            content_types = [content_types]
        if any(type_ not in ["text", "table"] for type_ in content_types):
            raise ValueError(
                "MemoryDocumentStore can do BM25 retrieval on no other document type than text or table."
            )
    else:
        # Build a new dictionary so the caller's filters are never mutated.
        filters = {**(filters or {}), "content_type": ["text", "table"]}
    all_documents = self.filter_documents(filters=filters)

    # Lowercase all documents. The documents that actually enter the corpus are tracked in a
    # parallel list so that score positions always map back to the right document.
    searchable_documents = []
    lower_case_documents = []
    for doc in all_documents:
        if doc.content_type == "text":
            text = doc.content
        elif doc.content_type == "table":
            # Tables are searched through the CSV rendering of their stringified content.
            text = doc.content.astype(str).to_csv(index=False)
        else:
            continue
        searchable_documents.append(doc)
        lower_case_documents.append(text.lower())

    # Bail out before tokenizing: there is nothing to rank and no progress bar to show.
    if not lower_case_documents:
        logger.info("No documents found for BM25 retrieval. Returning empty list.")
        return []

    # Tokenize the entire content of the document store
    tokenized_corpus = [
        self.tokenizer(doc) for doc in tqdm(lower_case_documents, unit=" docs", desc="Ranking by BM25...")
    ]

    # initialize BM25
    bm25_scorer = self.bm25_algorithm(tokenized_corpus, **self.bm25_parameters)
    # tokenize query
    tokenized_query = self.tokenizer(query.lower())
    # get scores for the query against the corpus
    docs_scores = bm25_scorer.get_scores(tokenized_query)
    if scale_score:
        docs_scores = [float(expit(np.asarray(score / SCALING_FACTOR))) for score in docs_scores]
    # get the last top_k indexes and reverse them
    top_docs_positions = np.argsort(docs_scores)[-top_k:][::-1]

    # Create documents with the BM25 score to return them
    return_documents = []
    for i in top_docs_positions:
        doc_fields = searchable_documents[i].to_dict()
        doc_fields["score"] = docs_scores[i]
        return_documents.append(Document(**doc_fields))
    return return_documents
|
||||
|
||||
0
test/preview/components/retrievers/__init__.py
Normal file
0
test/preview/components/retrievers/__init__.py
Normal file
134
test/preview/components/retrievers/test_memory_retriever.py
Normal file
134
test/preview/components/retrievers/test_memory_retriever.py
Normal file
@ -0,0 +1,134 @@
|
||||
from typing import Dict, Any, List
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Pipeline
|
||||
from haystack.preview.components.retrievers.memory import MemoryRetriever
|
||||
from haystack.preview.dataclasses import Document
|
||||
from haystack.preview.document_stores import MemoryDocumentStore
|
||||
|
||||
from test.preview.components.base import BaseTestComponent
|
||||
|
||||
|
||||
@pytest.fixture()
def mock_docs():
    """Provide five text documents, each one naming a different programming language."""
    contents = (
        "Javascript is a popular programming language",
        "Java is a popular programming language",
        "Python is a popular programming language",
        "Ruby is a popular programming language",
        "PHP is a popular programming language",
    )
    return [Document.from_dict({"content": text}) for text in contents]
|
||||
|
||||
|
||||
class Test_MemoryRetriever(BaseTestComponent):
    """Unit and integration tests for the MemoryRetriever component."""

    @pytest.mark.unit
    def test_save_load(self, tmp_path):
        # Save/load check (through the base-class helper) for a retriever built with defaults.
        retriever = MemoryRetriever(document_store_name="memory")
        self.assert_can_be_saved_and_loaded_in_pipeline(retriever, tmp_path)

    @pytest.mark.unit
    def test_save_load_with_parameters(self, tmp_path):
        # Same save/load check, this time with non-default init parameters.
        retriever = MemoryRetriever(document_store_name="memory", top_k=5, scale_score=False)
        self.assert_can_be_saved_and_loaded_in_pipeline(retriever, tmp_path)

    @pytest.mark.unit
    def test_init_default(self):
        # Only the store name is given: the defaults must hold the documented default values.
        expected_defaults = {"filters": {}, "top_k": 10, "scale_score": True}
        retriever = MemoryRetriever(document_store_name="memory")
        assert retriever.document_store_name == "memory"
        assert retriever.defaults == expected_defaults

    @pytest.mark.unit
    def test_init_with_parameters(self):
        # Explicit init parameters must end up in the defaults.
        expected_defaults = {"filters": {}, "top_k": 5, "scale_score": False}
        retriever = MemoryRetriever(document_store_name="memory-test", top_k=5, scale_score=False)
        assert retriever.document_store_name == "memory-test"
        assert retriever.defaults == expected_defaults

    @pytest.mark.unit
    def test_init_with_invalid_top_k_parameter(self):
        # A negative top_k is rejected at construction time.
        with pytest.raises(ValueError, match="top_k must be > 0, but got -2"):
            MemoryRetriever(document_store_name="memory-test", top_k=-2, scale_score=False)

    @pytest.mark.unit
    def test_valid_run(self, mock_docs):
        # Running the retriever directly on a filled store returns the best match first.
        top_k = 5
        store = MemoryDocumentStore()
        store.write_documents(mock_docs)
        retriever = MemoryRetriever(document_store_name="memory", top_k=top_k)
        result: MemoryRetriever.Output = retriever.run(
            data=MemoryRetriever.Input(query="PHP", stores={"memory": store})
        )

        assert getattr(result, "documents")
        assert len(result.documents) == top_k
        assert result.documents[0].content == "PHP is a popular programming language"

    @pytest.mark.unit
    def test_invalid_run_wrong_store_name(self):
        # The store is registered under a name the retriever does not look for.
        store = MemoryDocumentStore()
        retriever = MemoryRetriever(document_store_name="memory")
        with pytest.raises(ValueError, match=r"MemoryRetriever's document store 'memory' not found"):
            bad_input = MemoryRetriever.Input(
                query="test", top_k=10, scale_score=True, stores={"invalid_store": store}
            )
            retriever.run(bad_input)

    @pytest.mark.unit
    def test_invalid_run_wrong_store_type(self):
        # The name matches, but the registered object is not a MemoryDocumentStore.
        ds = MemoryDocumentStore()
        retriever = MemoryRetriever(document_store_name="memory")
        with pytest.raises(ValueError, match=r"MemoryRetriever can only be used with a MemoryDocumentStore instance."):
            bad_input = MemoryRetriever.Input(
                query="test", top_k=10, scale_score=True, stores={"memory": "not a MemoryDocumentStore"}
            )
            retriever.run(bad_input)

    @pytest.mark.integration
    @pytest.mark.parametrize(
        "query, query_result",
        [
            ("Javascript", "Javascript is a popular programming language"),
            ("Java", "Java is a popular programming language"),
        ],
    )
    def test_run_with_pipeline(self, mock_docs, query: str, query_result: str):
        # The retriever runs inside a Pipeline that owns the store; only the query is supplied.
        store = MemoryDocumentStore()
        store.write_documents(mock_docs)
        retriever = MemoryRetriever(document_store_name="memory")

        pipeline = Pipeline()
        pipeline.add_component("retriever", retriever)
        pipeline.add_store("memory", store)
        pipeline_input = {"retriever": MemoryRetriever.Input(query=query)}
        result: Dict[str, Any] = pipeline.run(data=pipeline_input)

        assert result
        assert "retriever" in result
        results_docs = result["retriever"].documents
        assert results_docs
        assert results_docs[0].content == query_result

    @pytest.mark.integration
    @pytest.mark.parametrize(
        "query, query_result, top_k",
        [
            ("Javascript", "Javascript is a popular programming language", 1),
            ("Java", "Java is a popular programming language", 2),
            ("Ruby", "Ruby is a popular programming language", 3),
        ],
    )
    def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int):
        # Same pipeline setup, with top_k passed in the run input to limit the result size.
        store = MemoryDocumentStore()
        store.write_documents(mock_docs)
        retriever = MemoryRetriever(document_store_name="memory")

        pipeline = Pipeline()
        pipeline.add_component("retriever", retriever)
        pipeline.add_store("memory", store)
        pipeline_input = {"retriever": MemoryRetriever.Input(query=query, top_k=top_k)}
        result: Dict[str, Any] = pipeline.run(data=pipeline_input)

        assert result
        assert "retriever" in result
        results_docs = result["retriever"].documents
        assert results_docs
        assert len(results_docs) == top_k
        assert results_docs[0].content == query_result
|
||||
@ -1,4 +1,9 @@
|
||||
import logging
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Document
|
||||
from haystack.preview.document_stores import MemoryDocumentStore
|
||||
|
||||
from test.preview.document_stores._base import DocumentStoreBaseTests
|
||||
@ -12,3 +17,191 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
|
||||
@pytest.fixture
def docstore(self) -> MemoryDocumentStore:
    # Hands each requesting test a freshly constructed MemoryDocumentStore built with default arguments.
    return MemoryDocumentStore()
|
||||
|
||||
@pytest.mark.unit
def test_bm25_retrieval(self, docstore):
    # Tests if the bm25_retrieval method returns the correct document based on the input query.
    # The store comes from the `docstore` fixture; re-creating it here would only shadow the fixture.
    docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
    docstore.write_documents(docs)
    results = docstore.bm25_retrieval(query="What languages?", top_k=1, filters={})
    assert len(results) == 1
    assert results[0].content == "Haystack supports multiple languages"
|
||||
|
||||
@pytest.mark.unit
def test_bm25_retrieval_with_empty_document_store(self, docstore, caplog):
    # With no documents in the store, bm25_retrieval must return an empty list and log why.
    caplog.set_level(logging.INFO)
    hits = docstore.bm25_retrieval(query="How to test this?", top_k=2)
    assert len(hits) == 0
    assert "No documents found for BM25 retrieval. Returning empty list." in caplog.text
|
||||
|
||||
@pytest.mark.unit
def test_bm25_retrieval_empty_query(self, docstore):
    # An empty query string must be rejected with a ValueError, even when the store has documents.
    docstore.write_documents(
        [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
    )
    with pytest.raises(ValueError, match=r"Query should be a non-empty string"):
        docstore.bm25_retrieval(query="", top_k=1)
|
||||
|
||||
@pytest.mark.unit
def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_string(self, docstore):
    # A content_type filter given as the string "text" must leave only the text document searchable.
    table_doc = Document.from_dict(
        {"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}
    )
    text_doc = Document(content="Haystack supports multiple languages")
    docstore.write_documents([table_doc, text_doc])
    hits = docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": "text"})
    assert len(hits) == 1
    assert hits[0].content == "Haystack supports multiple languages"
|
||||
|
||||
@pytest.mark.unit
def test_bm25_retrieval_filter_only_one_allowed_doc_type_as_list(self, docstore):
    # A content_type filter given as the list ["text"] must leave only the text document searchable.
    table_doc = Document.from_dict(
        {"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}
    )
    text_doc = Document(content="Haystack supports multiple languages")
    docstore.write_documents([table_doc, text_doc])
    hits = docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["text"]})
    assert len(hits) == 1
    assert hits[0].content == "Haystack supports multiple languages"
|
||||
|
||||
@pytest.mark.unit
def test_bm25_retrieval_filter_two_allowed_doc_type_as_list(self, docstore):
    # Filtering on both allowed content types must keep the table and the text document.
    table_doc = Document.from_dict(
        {"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}
    )
    text_doc = Document(content="Haystack supports multiple languages")
    docstore.write_documents([table_doc, text_doc])
    hits = docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["text", "table"]})
    assert len(hits) == 2
|
||||
|
||||
@pytest.mark.unit
def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_string(self, docstore):
    # A content_type filter naming an unsupported type (as a string) must raise a ValueError.
    table_doc = Document(content=pd.DataFrame({"language": ["Python", "Java"]}), content_type="table")
    text_doc = Document(content="Haystack supports multiple languages")
    docstore.write_documents([table_doc, text_doc])
    with pytest.raises(
        ValueError, match="MemoryDocumentStore can do BM25 retrieval on no other document type than text or table."
    ):
        docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": "audio"})
|
||||
|
||||
@pytest.mark.unit
def test_bm25_retrieval_filter_only_one_not_allowed_doc_type_as_list(self, docstore):
    # A content_type filter naming an unsupported type (as a one-element list) must raise a ValueError.
    table_doc = Document(content=pd.DataFrame({"language": ["Python", "Java"]}), content_type="table")
    text_doc = Document(content="Haystack supports multiple languages")
    docstore.write_documents([table_doc, text_doc])
    with pytest.raises(
        ValueError, match="MemoryDocumentStore can do BM25 retrieval on no other document type than text or table."
    ):
        docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["audio"]})
|
||||
|
||||
@pytest.mark.unit
def test_bm25_retrieval_filter_two_not_all_allowed_doc_type_as_list(self, docstore):
    # Mixing an allowed content type with an unsupported one in the filter must still raise a ValueError.
    table_doc = Document.from_dict(
        {"content": pd.DataFrame({"language": ["Python", "Java"]}), "content_type": "table"}
    )
    text_doc = Document(content="Haystack supports multiple languages")
    docstore.write_documents([table_doc, text_doc])
    with pytest.raises(
        ValueError, match="MemoryDocumentStore can do BM25 retrieval on no other document type than text or table."
    ):
        docstore.bm25_retrieval(query="Python", top_k=3, filters={"content_type": ["text", "audio"]})
|
||||
|
||||
@pytest.mark.unit
def test_bm25_retrieval_with_different_top_k(self, docstore):
    # The number of returned documents must follow the top_k parameter.
    contents = ("Hello world", "Haystack supports multiple languages", "Python is a popular programming language")
    docstore.write_documents([Document(content=text) for text in contents])

    # First ask for two documents, then for three.
    for requested in (2, 3):
        results = docstore.bm25_retrieval(query="languages", top_k=requested)
        assert len(results) == requested
|
||||
|
||||
# Test two queries and make sure the results are different
|
||||
@pytest.mark.unit
def test_bm25_retrieval_with_two_queries(self, docstore):
    # Different queries against the same store must each return their own best match.
    contents = (
        "Javascript is a popular programming language",
        "Java is a popular programming language",
        "Python is a popular programming language",
        "Ruby is a popular programming language",
        "PHP is a popular programming language",
    )
    docstore.write_documents([Document(content=text) for text in contents])

    expectations = (
        ("Java", "Java is a popular programming language"),
        ("Python", "Python is a popular programming language"),
    )
    for query, best_match in expectations:
        results = docstore.bm25_retrieval(query=query, top_k=1)
        assert results[0].content == best_match
|
||||
|
||||
# Test a query, add a new document and make sure results are appropriately updated
|
||||
@pytest.mark.unit
def test_bm25_retrieval_with_updated_docs(self, docstore):
    # Documents written after an earlier retrieval must be taken into account by later retrievals.
    python_content = "Python is a popular programming language"

    docstore.write_documents([Document(content="Hello world")])
    hits = docstore.bm25_retrieval(query="Python", top_k=1)
    assert len(hits) == 1

    # Once the matching document is in the store it must be the top hit.
    docstore.write_documents([Document(content=python_content)])
    hits = docstore.bm25_retrieval(query="Python", top_k=1)
    assert len(hits) == 1
    assert hits[0].content == "Python is a popular programming language"

    # Adding an unrelated document afterwards must not change the top hit.
    docstore.write_documents([Document(content="Java is a popular programming language")])
    hits = docstore.bm25_retrieval(query="Python", top_k=1)
    assert len(hits) == 1
    assert hits[0].content == "Python is a popular programming language"
|
||||
|
||||
@pytest.mark.unit
def test_bm25_retrieval_with_scale_score(self, docstore):
    # Scaled scores lie between 0 and 1 and differ from the unscaled score of the same query.
    docstore.write_documents([Document(content="Python programming"), Document(content="Java programming")])

    scaled = docstore.bm25_retrieval(query="Python", top_k=1, scale_score=True)
    # Confirm that score is scaled between 0 and 1
    assert 0 <= scaled[0].score <= 1

    # Same query without scaling: the score must not be the same value
    unscaled = docstore.bm25_retrieval(query="Python", top_k=1, scale_score=False)
    assert unscaled[0].score != scaled[0].score
|
||||
|
||||
@pytest.mark.unit
def test_bm25_retrieval_with_table_content(self, docstore):
    # A table document that matches the query must be returned with its DataFrame content intact.
    table_content = pd.DataFrame({"language": ["Python", "Java"], "use": ["Data Science", "Web Development"]})
    table_doc = Document(content=table_content, content_type="table")
    gardening_doc = Document(content="Gardening", content_type="text")
    birds_doc = Document(content="Bird watching", content_type="text")
    docstore.write_documents([table_doc, gardening_doc, birds_doc])

    results = docstore.bm25_retrieval(query="Java", top_k=1)
    assert len(results) == 1
    returned_content = results[0].content
    assert isinstance(returned_content, pd.DataFrame)
    assert returned_content.equals(table_content)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user