Change Document.__eq__ to compare all fields (#6323)

This commit is contained in:
Silvano Cerza 2023-11-16 17:17:43 +01:00 committed by GitHub
parent ff3165b8b8
commit 6dda6e5b2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 9 deletions

View File

@ -82,11 +82,12 @@ class Document(metaclass=_BackwardCompatible):
def __eq__(self, other):
"""
Compares documents for equality. Uses the id to check whether the documents are supposed to be the same.
Compares Documents for equality.
Two Documents are considered equals if their dictionary representation is identical.
"""
if type(self) == type(other):
return self.id == other.id
return False
if type(self) != type(other):
return False
return self.to_dict() == other.to_dict()
def __post_init__(self):
"""

View File

@ -0,0 +1,14 @@
---
preview:
- |
Refactor `Document.__eq__()` so it compares the `Document`s dictionary
representation instead of only their `id`.
Previously this comparison would have unexpectedly worked:
```python
first_doc = Document(id="1", content="Hey!")
second_doc = Document(id="1", content="Hello!")
assert first_doc == second_doc
first_doc.content = "Howdy!"
assert first_doc == second_doc
```
With this change the last comparison would rightly fail.

View File

@ -199,10 +199,10 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
]
docstore.write_documents(docs)
results = docstore.bm25_retrieval(query="Gardening", top_k=2)
assert document in results
assert document.id in [d.id for d in results]
assert "both text and dataframe content" in caplog.text
results = docstore.bm25_retrieval(query="Python", top_k=2)
assert document not in results
assert document.id not in [d.id for d in results]
@pytest.mark.unit
def test_bm25_retrieval_default_filter_for_text_and_dataframes(self, docstore: InMemoryDocumentStore):
@ -217,7 +217,8 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
docs = [Document(), selected_document, Document(content="Bird watching")]
docstore.write_documents(docs)
results = docstore.bm25_retrieval(query="Java", top_k=10, filters={"selected": True})
assert results == [selected_document]
assert len(results) == 1
assert results[0].id == selected_document.id
@pytest.mark.unit
def test_bm25_retrieval_with_filters_keeps_default_filters(self, docstore: InMemoryDocumentStore):
@ -232,7 +233,8 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
docs = [Document(), Document(content="Gardening"), Document(content="Bird watching"), document]
docstore.write_documents(docs)
results = docstore.bm25_retrieval(query="Java", top_k=10, filters={"content": None})
assert results == [document]
assert len(results) == 1
assert results[0].id == document.id
@pytest.mark.unit
def test_bm25_retrieval_with_documents_with_mixed_content(self, docstore: InMemoryDocumentStore):
@ -240,7 +242,8 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904
docs = [Document(embedding=[1.0, 2.0, 3.0]), double_document, Document(content="Bird watching")]
docstore.write_documents(docs)
results = docstore.bm25_retrieval(query="Java", top_k=10, filters={"embedding": {"$not": None}})
assert results == [double_document]
assert len(results) == 1
assert results[0].id == double_document.id
@pytest.mark.unit
def test_embedding_retrieval(self):