fix!: InMemoryBM25Retriever no longer returns documents that have a score of 0.0 (#6717)

* fix!: `InMemoryBM25Retriever` no longer returns documents that have a score of 0.0

Also update tests to accommodate the new behavior.

* Remove superfluous code
This commit is contained in:
Madeesh Kannan 2024-01-12 17:50:55 +01:00 committed by GitHub
parent 4647f2a506
commit a5189dd035
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 14 deletions

View File

@ -212,8 +212,11 @@ class InMemoryDocumentStore:
return_documents = [] return_documents = []
for i in top_docs_positions: for i in top_docs_positions:
doc = all_documents[i] doc = all_documents[i]
score = docs_scores[i]
if score <= 0.0:
continue
doc_fields = doc.to_dict() doc_fields = doc.to_dict()
doc_fields["score"] = docs_scores[i] doc_fields["score"] = score
return_document = Document.from_dict(doc_fields) return_document = Document.from_dict(doc_fields)
return_documents.append(return_document) return_documents.append(return_document)
return return_documents return return_documents

View File

@ -0,0 +1,3 @@
---
fixes:
- Prevent InMemoryBM25Retriever from returning documents with a score of 0.0.

View File

@ -113,15 +113,14 @@ class TestMemoryBM25Retriever:
InMemoryBM25Retriever.from_dict(data) InMemoryBM25Retriever.from_dict(data)
def test_retriever_valid_run(self, mock_docs): def test_retriever_valid_run(self, mock_docs):
top_k = 5
ds = InMemoryDocumentStore() ds = InMemoryDocumentStore()
ds.write_documents(mock_docs) ds.write_documents(mock_docs)
retriever = InMemoryBM25Retriever(ds, top_k=top_k) retriever = InMemoryBM25Retriever(ds, top_k=5)
result = retriever.run(query="PHP") result = retriever.run(query="PHP")
assert "documents" in result assert "documents" in result
assert len(result["documents"]) == top_k assert len(result["documents"]) == 1
assert result["documents"][0].content == "PHP is a popular programming language" assert result["documents"][0].content == "PHP is a popular programming language"
def test_invalid_run_wrong_store_type(self): def test_invalid_run_wrong_store_type(self):
@ -174,5 +173,5 @@ class TestMemoryBM25Retriever:
assert "retriever" in result assert "retriever" in result
results_docs = result["retriever"]["documents"] results_docs = result["retriever"]["documents"]
assert results_docs assert results_docs
assert len(results_docs) == top_k assert len(results_docs) == 1
assert results_docs[0].content == query_result assert results_docs[0].content == query_result

View File

@ -5,8 +5,8 @@ import pandas as pd
import pytest import pytest
from haystack import Document from haystack import Document
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.testing.document_store import DocumentStoreBaseTests from haystack.testing.document_store import DocumentStoreBaseTests
@ -17,7 +17,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
@pytest.fixture @pytest.fixture
def document_store(self) -> InMemoryDocumentStore: def document_store(self) -> InMemoryDocumentStore:
return InMemoryDocumentStore() return InMemoryDocumentStore(bm25_algorithm="BM25L")
def test_to_dict(self): def test_to_dict(self):
store = InMemoryDocumentStore() store = InMemoryDocumentStore()
@ -73,7 +73,6 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
document_store.write_documents(docs) document_store.write_documents(docs)
def test_bm25_retrieval(self, document_store: InMemoryDocumentStore): def test_bm25_retrieval(self, document_store: InMemoryDocumentStore):
document_store = InMemoryDocumentStore()
# Tests if the bm25_retrieval method returns the correct document based on the input query. # Tests if the bm25_retrieval method returns the correct document based on the input query.
docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")] docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
document_store.write_documents(docs) document_store.write_documents(docs)
@ -106,7 +105,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
document_store.write_documents(docs) document_store.write_documents(docs)
# top_k = 2 # top_k = 2
results = document_store.bm25_retrieval(query="languages", top_k=2) results = document_store.bm25_retrieval(query="language", top_k=2)
assert len(results) == 2 assert len(results) == 2
# top_k = 3 # top_k = 3
@ -141,7 +140,7 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
document_store.write_documents(docs) document_store.write_documents(docs)
results = document_store.bm25_retrieval(query="Python", top_k=1) results = document_store.bm25_retrieval(query="Python", top_k=1)
assert len(results) == 1 assert len(results) == 0
document_store.write_documents([Document(content="Python is a popular programming language")]) document_store.write_documents([Document(content="Python is a popular programming language")])
results = document_store.bm25_retrieval(query="Python", top_k=1) results = document_store.bm25_retrieval(query="Python", top_k=1)
@ -199,10 +198,10 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
docs = [Document(), Document(content="Gardening"), Document(content="Bird watching")] docs = [Document(), Document(content="Gardening"), Document(content="Bird watching")]
document_store.write_documents(docs) document_store.write_documents(docs)
results = document_store.bm25_retrieval(query="doesn't matter, top_k is 10", top_k=10) results = document_store.bm25_retrieval(query="doesn't matter, top_k is 10", top_k=10)
assert len(results) == 2 assert len(results) == 0
def test_bm25_retrieval_with_filters(self, document_store: InMemoryDocumentStore): def test_bm25_retrieval_with_filters(self, document_store: InMemoryDocumentStore):
selected_document = Document(content="Gardening", meta={"selected": True}) selected_document = Document(content="Java is, well...", meta={"selected": True})
docs = [Document(), selected_document, Document(content="Bird watching")] docs = [Document(), selected_document, Document(content="Bird watching")]
document_store.write_documents(docs) document_store.write_documents(docs)
results = document_store.bm25_retrieval(query="Java", top_k=10, filters={"selected": True}) results = document_store.bm25_retrieval(query="Java", top_k=10, filters={"selected": True})
@ -224,10 +223,10 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
assert results[0].id == document.id assert results[0].id == document.id
def test_bm25_retrieval_with_documents_with_mixed_content(self, document_store: InMemoryDocumentStore): def test_bm25_retrieval_with_documents_with_mixed_content(self, document_store: InMemoryDocumentStore):
double_document = Document(content="Gardening", embedding=[1.0, 2.0, 3.0]) double_document = Document(content="Gardening is a hobby", embedding=[1.0, 2.0, 3.0])
docs = [Document(embedding=[1.0, 2.0, 3.0]), double_document, Document(content="Bird watching")] docs = [Document(embedding=[1.0, 2.0, 3.0]), double_document, Document(content="Bird watching")]
document_store.write_documents(docs) document_store.write_documents(docs)
results = document_store.bm25_retrieval(query="Java", top_k=10, filters={"embedding": {"$not": None}}) results = document_store.bm25_retrieval(query="Gardening", top_k=10, filters={"embedding": {"$not": None}})
assert len(results) == 1 assert len(results) == 1
assert results[0].id == double_document.id assert results[0].id == double_document.id