2024-05-09 15:40:36 +02:00
|
|
|
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
|
|
#
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
2023-06-27 17:42:23 +02:00
|
|
|
import logging
|
2023-08-11 14:45:56 +02:00
|
|
|
from unittest.mock import patch
|
2023-06-27 17:42:23 +02:00
|
|
|
|
|
|
|
import pandas as pd
|
2023-04-13 09:36:23 +02:00
|
|
|
import pytest
|
2023-06-27 17:42:23 +02:00
|
|
|
|
2023-11-24 14:48:43 +01:00
|
|
|
from haystack import Document
|
2024-01-10 21:20:42 +01:00
|
|
|
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
|
2024-01-12 17:50:55 +01:00
|
|
|
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
2023-11-24 14:48:43 +01:00
|
|
|
from haystack.testing.document_store import DocumentStoreBaseTests
|
2023-04-13 09:36:23 +02:00
|
|
|
|
|
|
|
|
2023-10-31 12:44:04 +01:00
|
|
|
class TestMemoryDocumentStore(DocumentStoreBaseTests):  # pylint: disable=R0904
    """
    Test InMemoryDocumentStore's specific features
    """

    @pytest.fixture
    def document_store(self) -> InMemoryDocumentStore:
        """Provide a fresh InMemoryDocumentStore that uses the BM25L algorithm."""
        return InMemoryDocumentStore(bm25_algorithm="BM25L")

    def test_to_dict(self):
        """Serializing a default store yields the expected type path and default init parameters."""
        store = InMemoryDocumentStore()
        data = store.to_dict()
        assert data == {
            "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
            "init_parameters": {
                "bm25_tokenization_regex": r"(?u)\b\w\w+\b",
                "bm25_algorithm": "BM25L",
                "bm25_parameters": {},
                "embedding_similarity_function": "dot_product",
                "index": store.index,
            },
        }

    def test_to_dict_with_custom_init_parameters(self):
        """Custom init parameters are reproduced verbatim in the serialized dict."""
        store = InMemoryDocumentStore(
            bm25_tokenization_regex="custom_regex",
            bm25_algorithm="BM25Plus",
            bm25_parameters={"key": "value"},
            embedding_similarity_function="cosine",
            index="my_cool_index",
        )
        data = store.to_dict()
        assert data == {
            "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
            "init_parameters": {
                "bm25_tokenization_regex": "custom_regex",
                "bm25_algorithm": "BM25Plus",
                "bm25_parameters": {"key": "value"},
                "embedding_similarity_function": "cosine",
                "index": "my_cool_index",
            },
        }

    @patch("haystack.document_stores.in_memory.document_store.re")
    def test_from_dict(self, mock_regex):
        """Deserializing restores the init parameters and compiles the custom tokenization regex."""
        data = {
            "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
            "init_parameters": {
                "bm25_tokenization_regex": "custom_regex",
                "bm25_algorithm": "BM25Plus",
                "bm25_parameters": {"key": "value"},
                "index": "my_cool_index",
            },
        }
        store = InMemoryDocumentStore.from_dict(data)
        # `re` is patched in the store's module, so we can check which pattern was compiled.
        mock_regex.compile.assert_called_with("custom_regex")
        assert store.tokenizer
        assert store.bm25_algorithm == "BM25Plus"
        assert store.bm25_parameters == {"key": "value"}
        assert store.index == "my_cool_index"

    def test_invalid_bm25_algorithm(self):
        """An unsupported BM25 algorithm name is rejected with a ValueError at construction time."""
        with pytest.raises(ValueError, match="BM25 algorithm 'invalid' is not supported"):
            InMemoryDocumentStore(bm25_algorithm="invalid")

    def test_write_documents(self, document_store):
        """Writing a new document returns 1; writing the same document again raises DuplicateDocumentError."""
        docs = [Document(id="1")]
        assert document_store.write_documents(docs) == 1
        with pytest.raises(DuplicateDocumentError):
            document_store.write_documents(docs)

    def test_bm25_retrieval(self, document_store: InMemoryDocumentStore):
        """bm25_retrieval returns the document that best matches the input query."""
        docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
        document_store.write_documents(docs)
        results = document_store.bm25_retrieval(query="What languages?", top_k=1)
        assert len(results) == 1
        assert results[0].content == "Haystack supports multiple languages"

    def test_bm25_retrieval_with_empty_document_store(self, document_store: InMemoryDocumentStore, caplog):
        """bm25_retrieval returns an empty list, and logs why, when the DocumentStore has no documents."""
        caplog.set_level(logging.INFO)
        results = document_store.bm25_retrieval(query="How to test this?", top_k=2)
        assert len(results) == 0
        assert "No documents found for BM25 retrieval. Returning empty list." in caplog.text

    def test_bm25_retrieval_empty_query(self, document_store: InMemoryDocumentStore):
        """bm25_retrieval raises a ValueError when the query is an empty string."""
        docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
        document_store.write_documents(docs)
        with pytest.raises(ValueError, match="Query should be a non-empty string"):
            document_store.bm25_retrieval(query="", top_k=1)

    def test_bm25_retrieval_with_different_top_k(self, document_store: InMemoryDocumentStore):
        """The number of documents returned by bm25_retrieval follows the top_k parameter."""
        docs = [
            Document(content="Hello world"),
            Document(content="Haystack supports multiple languages"),
            Document(content="Python is a popular programming language"),
        ]
        document_store.write_documents(docs)

        # top_k = 2
        results = document_store.bm25_retrieval(query="language", top_k=2)
        assert len(results) == 2

        # top_k = 3
        results = document_store.bm25_retrieval(query="languages", top_k=3)
        assert len(results) == 3

    def test_bm25_plus_retrieval(self):
        """With BM25Plus, the document containing the exact query term is ranked first."""
        doc_store = InMemoryDocumentStore(bm25_algorithm="BM25Plus")
        docs = [
            Document(content="Hello world"),
            Document(content="Haystack supports multiple languages"),
            Document(content="Python is a popular programming language"),
        ]
        doc_store.write_documents(docs)

        results = doc_store.bm25_retrieval(query="language", top_k=1)
        assert len(results) == 1
        assert results[0].content == "Python is a popular programming language"

    def test_bm25_retrieval_with_two_queries(self, document_store: InMemoryDocumentStore):
        """bm25_retrieval returns different documents for different queries."""
        docs = [
            Document(content="Javascript is a popular programming language"),
            Document(content="Java is a popular programming language"),
            Document(content="Python is a popular programming language"),
            Document(content="Ruby is a popular programming language"),
            Document(content="PHP is a popular programming language"),
        ]
        document_store.write_documents(docs)

        results = document_store.bm25_retrieval(query="Java", top_k=1)
        assert results[0].content == "Java is a popular programming language"

        results = document_store.bm25_retrieval(query="Python", top_k=1)
        assert results[0].content == "Python is a popular programming language"

    def test_bm25_retrieval_with_updated_docs(self, document_store: InMemoryDocumentStore):
        """
        bm25_retrieval reflects documents added after earlier queries.

        Run a query, add a new document, and check that the results are updated accordingly.
        """
        docs = [Document(content="Hello world")]
        document_store.write_documents(docs)

        results = document_store.bm25_retrieval(query="Python", top_k=1)
        assert len(results) == 0

        document_store.write_documents([Document(content="Python is a popular programming language")])
        results = document_store.bm25_retrieval(query="Python", top_k=1)
        assert len(results) == 1
        assert results[0].content == "Python is a popular programming language"

        # Adding an unrelated document must not change the top result for the same query.
        document_store.write_documents([Document(content="Java is a popular programming language")])
        results = document_store.bm25_retrieval(query="Python", top_k=1)
        assert len(results) == 1
        assert results[0].content == "Python is a popular programming language"

    def test_bm25_retrieval_with_scale_score(self, document_store: InMemoryDocumentStore):
        """With scale_score=True the score lies in [0, 1] and differs from the unscaled score."""
        docs = [Document(content="Python programming"), Document(content="Java programming")]
        document_store.write_documents(docs)

        results1 = document_store.bm25_retrieval(query="Python", top_k=1, scale_score=True)
        # Confirm that score is scaled between 0 and 1
        assert results1[0].score is not None
        assert 0.0 <= results1[0].score <= 1.0

        # Same query, different scale, scores differ when not scaled
        results = document_store.bm25_retrieval(query="Python", top_k=1, scale_score=False)
        assert results[0].score != results1[0].score

    def test_bm25_retrieval_with_non_scaled_BM25Okapi(self):
        """Negative BM25Okapi scores are kept when unscaled and mapped into [0, 1] when scaled."""
        # Highly repetitive documents make BM25Okapi return negative scores, which should not be filtered if the
        # scores are not scaled
        docs = [
            Document(
                content="""Use pip to install a basic version of Haystack's latest release: pip install
                farm-haystack. All the core Haystack components live in the haystack repo. But there's also the
                haystack-extras repo which contains components that are not as widely used, and you need to
                install them separately."""
            ),
            Document(
                content="""Use pip to install a basic version of Haystack's latest release: pip install
                farm-haystack[inference]. All the core Haystack components live in the haystack repo. But there's
                also the haystack-extras repo which contains components that are not as widely used, and you need
                to install them separately."""
            ),
            Document(
                content="""Use pip to install only the Haystack 2.0 code: pip install haystack-ai. The haystack-ai
                package is built on the main branch which is an unstable beta version, but it's useful if you want
                to try the new features as soon as they are merged."""
            ),
        ]
        document_store = InMemoryDocumentStore(bm25_algorithm="BM25Okapi")
        document_store.write_documents(docs)

        results1 = document_store.bm25_retrieval(query="Haystack installation", top_k=10, scale_score=False)
        assert len(results1) == 3
        assert all(res.score < 0.0 for res in results1)

        results2 = document_store.bm25_retrieval(query="Haystack installation", top_k=10, scale_score=True)
        assert len(results2) == 3
        assert all(0.0 <= res.score <= 1.0 for res in results2)

    def test_bm25_retrieval_with_table_content(self, document_store: InMemoryDocumentStore):
        """bm25_retrieval matches a Document holding a dataframe and returns that dataframe unchanged."""
        table_content = pd.DataFrame({"language": ["Python", "Java"], "use": ["Data Science", "Web Development"]})
        docs = [Document(dataframe=table_content), Document(content="Gardening"), Document(content="Bird watching")]
        document_store.write_documents(docs)
        results = document_store.bm25_retrieval(query="Java", top_k=1)
        assert len(results) == 1

        df = results[0].dataframe
        assert isinstance(df, pd.DataFrame)
        assert df.equals(table_content)

    def test_bm25_retrieval_with_text_and_table_content(self, document_store: InMemoryDocumentStore, caplog):
        """
        A Document with both text and a dataframe is retrieved via its text, not its dataframe.

        A warning mentioning "both text and dataframe content" is expected in the log.
        """
        table_content = pd.DataFrame({"language": ["Python", "Java"], "use": ["Data Science", "Web Development"]})
        document = Document(content="Gardening", dataframe=table_content)
        docs = [
            Document(content="Python"),
            Document(content="Bird Watching"),
            Document(content="Gardening"),
            Document(content="Java"),
            document,
        ]
        document_store.write_documents(docs)
        results = document_store.bm25_retrieval(query="Gardening", top_k=2)
        assert document.id in [d.id for d in results]
        assert "both text and dataframe content" in caplog.text
        # "Python" only appears in the dataframe of `document`, so it must not be matched through it.
        results = document_store.bm25_retrieval(query="Python", top_k=2)
        assert document.id not in [d.id for d in results]

    def test_bm25_retrieval_default_filter_for_text_and_dataframes(self, document_store: InMemoryDocumentStore):
        """A query that matches no indexed terms returns nothing, even when empty Documents are stored."""
        docs = [Document(), Document(content="Gardening"), Document(content="Bird watching")]
        document_store.write_documents(docs)
        results = document_store.bm25_retrieval(query="doesn't matter, top_k is 10", top_k=10)
        assert len(results) == 0

    def test_bm25_retrieval_with_filters(self, document_store: InMemoryDocumentStore):
        """Metadata filters restrict bm25_retrieval to the matching document."""
        selected_document = Document(content="Java is, well...", meta={"selected": True})
        docs = [Document(), selected_document, Document(content="Bird watching")]
        document_store.write_documents(docs)
        results = document_store.bm25_retrieval(query="Java", top_k=10, filters={"selected": True})
        assert len(results) == 1
        assert results[0].id == selected_document.id

    def test_bm25_retrieval_with_filters_keeps_default_filters(self, document_store: InMemoryDocumentStore):
        """A content-less Document matching the user filter is still not returned by bm25_retrieval."""
        docs = [Document(meta={"selected": True}), Document(content="Gardening"), Document(content="Bird watching")]
        document_store.write_documents(docs)
        results = document_store.bm25_retrieval(query="Java", top_k=10, filters={"selected": True})
        assert len(results) == 0

    def test_bm25_retrieval_with_filters_on_text_or_dataframe(self, document_store: InMemoryDocumentStore):
        """Filtering on `content: None` returns only the dataframe-backed Document that matches the query."""
        document = Document(dataframe=pd.DataFrame({"language": ["Python", "Java"], "use": ["Data Science", "Web"]}))
        docs = [Document(), Document(content="Gardening"), Document(content="Bird watching"), document]
        document_store.write_documents(docs)
        results = document_store.bm25_retrieval(query="Java", top_k=10, filters={"content": None})
        assert len(results) == 1
        assert results[0].id == document.id

    def test_bm25_retrieval_with_documents_with_mixed_content(self, document_store: InMemoryDocumentStore):
        """An embedding filter combined with a BM25 query returns only the Document that has both."""
        double_document = Document(content="Gardening is a hobby", embedding=[1.0, 2.0, 3.0])
        docs = [Document(embedding=[1.0, 2.0, 3.0]), double_document, Document(content="Bird watching")]
        document_store.write_documents(docs)
        results = document_store.bm25_retrieval(query="Gardening", top_k=10, filters={"embedding": {"$not": None}})
        assert len(results) == 1
        assert results[0].id == double_document.id

    def test_embedding_retrieval(self):
        """embedding_retrieval returns the document most similar to the query embedding."""
        docstore = InMemoryDocumentStore(embedding_similarity_function="cosine")
        docs = [
            Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
            Document(content="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]),
        ]
        docstore.write_documents(docs)
        results = docstore.embedding_retrieval(
            query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=1, filters={}, scale_score=False
        )
        assert len(results) == 1
        assert results[0].content == "Haystack supports multiple languages"

    def test_embedding_retrieval_invalid_query(self):
        """An empty or non-float query embedding is rejected with a ValueError."""
        docstore = InMemoryDocumentStore()
        with pytest.raises(ValueError, match="query_embedding should be a non-empty list of floats"):
            docstore.embedding_retrieval(query_embedding=[])
        with pytest.raises(ValueError, match="query_embedding should be a non-empty list of floats"):
            docstore.embedding_retrieval(query_embedding=["invalid", "list", "of", "strings"])  # type: ignore

    def test_embedding_retrieval_no_embeddings(self, caplog):
        """embedding_retrieval returns an empty list, and logs a warning, when no stored Document has an embedding."""
        caplog.set_level(logging.WARNING)
        docstore = InMemoryDocumentStore()
        docs = [Document(content="Hello world"), Document(content="Haystack supports multiple languages")]
        docstore.write_documents(docs)
        results = docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1])
        assert len(results) == 0
        assert "No Documents found with embeddings. Returning empty list." in caplog.text

    def test_embedding_retrieval_some_documents_wo_embeddings(self, caplog):
        """Documents without an embedding are skipped, and this is logged."""
        caplog.set_level(logging.INFO)
        docstore = InMemoryDocumentStore()
        docs = [
            Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
            Document(content="Haystack supports multiple languages"),
        ]
        docstore.write_documents(docs)
        docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1])
        assert "Skipping some Documents that don't have an embedding." in caplog.text

    def test_embedding_retrieval_documents_different_embedding_sizes(self):
        """Stored Documents with mismatched embedding sizes cause a DocumentStoreError at retrieval time."""
        docstore = InMemoryDocumentStore()
        docs = [
            Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
            Document(content="Haystack supports multiple languages", embedding=[1.0, 1.0]),
        ]
        docstore.write_documents(docs)

        with pytest.raises(DocumentStoreError, match="The embedding size of all Documents should be the same."):
            docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1])

    def test_embedding_retrieval_query_documents_different_embedding_sizes(self):
        """A query embedding whose size differs from the stored Documents' raises a DocumentStoreError."""
        docstore = InMemoryDocumentStore()
        docs = [Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4])]
        docstore.write_documents(docs)

        with pytest.raises(
            DocumentStoreError,
            match="The embedding size of the query should be the same as the embedding size of the Documents.",
        ):
            docstore.embedding_retrieval(query_embedding=[0.1, 0.1])

    def test_embedding_retrieval_with_different_top_k(self):
        """The number of documents returned by embedding_retrieval follows the top_k parameter."""
        docstore = InMemoryDocumentStore()
        docs = [
            Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
            Document(content="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]),
            Document(content="Python is a popular programming language", embedding=[0.5, 0.5, 0.5, 0.5]),
        ]
        docstore.write_documents(docs)

        results = docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=2)
        assert len(results) == 2

        results = docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=3)
        assert len(results) == 3

    def test_embedding_retrieval_with_scale_score(self):
        """With scale_score=True the similarity lies in [0, 1] and differs from the unscaled value."""
        docstore = InMemoryDocumentStore()
        docs = [
            Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
            Document(content="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]),
            Document(content="Python is a popular programming language", embedding=[0.5, 0.5, 0.5, 0.5]),
        ]
        docstore.write_documents(docs)

        results1 = docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=1, scale_score=True)
        # Confirm that score is scaled between 0 and 1
        assert results1[0].score is not None
        assert 0.0 <= results1[0].score <= 1.0

        # Same query, different scale, scores differ when not scaled
        results = docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=1, scale_score=False)
        assert results[0].score != results1[0].score

    def test_embedding_retrieval_return_embedding(self):
        """The return_embedding flag controls whether retrieved Documents keep their embedding."""
        docstore = InMemoryDocumentStore(embedding_similarity_function="cosine")
        docs = [
            Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4]),
            Document(content="Haystack supports multiple languages", embedding=[1.0, 1.0, 1.0, 1.0]),
        ]
        docstore.write_documents(docs)

        results = docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=1, return_embedding=False)
        assert results[0].embedding is None

        results = docstore.embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=1, return_embedding=True)
        assert results[0].embedding == [1.0, 1.0, 1.0, 1.0]

    def test_compute_cosine_similarity_scores(self):
        """Cosine similarity scores against a constant query vector match the analytic values."""
        docstore = InMemoryDocumentStore(embedding_similarity_function="cosine")
        docs = [
            Document(content="Document 1", embedding=[1.0, 0.0, 0.0, 0.0]),
            Document(content="Document 2", embedding=[1.0, 1.0, 1.0, 1.0]),
        ]

        scores = docstore._compute_query_embedding_similarity_scores(
            embedding=[0.1, 0.1, 0.1, 0.1], documents=docs, scale_score=False
        )
        # Floating-point results: compare with a tolerance instead of exact equality.
        assert scores == pytest.approx([0.5, 1.0])

    def test_compute_dot_product_similarity_scores(self):
        """Dot-product similarity scores against a constant query vector match the analytic values."""
        docstore = InMemoryDocumentStore(embedding_similarity_function="dot_product")
        docs = [
            Document(content="Document 1", embedding=[1.0, 0.0, 0.0, 0.0]),
            Document(content="Document 2", embedding=[1.0, 1.0, 1.0, 1.0]),
        ]

        scores = docstore._compute_query_embedding_similarity_scores(
            embedding=[0.1, 0.1, 0.1, 0.1], documents=docs, scale_score=False
        )
        # Summing 0.1 four times is not exact in binary floating point; compare with a tolerance.
        assert scores == pytest.approx([0.1, 0.4])

    def test_multiple_document_stores_using_same_index(self):
        """Two stores created with the same index name see each other's writes and deletes."""
        index = "test_multiple_document_stores_using_same_index"
        document_store_1 = InMemoryDocumentStore(index=index)
        document_store_2 = InMemoryDocumentStore(index=index)

        assert document_store_1.count_documents() == document_store_2.count_documents() == 0

        doc_1 = Document(content="Hello world")
        document_store_1.write_documents([doc_1])
        assert document_store_1.count_documents() == document_store_2.count_documents() == 1

        assert document_store_1.filter_documents() == document_store_2.filter_documents() == [doc_1]

        doc_2 = Document(content="Hello another world")
        document_store_2.write_documents([doc_2])
        assert document_store_1.count_documents() == document_store_2.count_documents() == 2

        assert document_store_1.filter_documents() == document_store_2.filter_documents() == [doc_1, doc_2]

        document_store_1.delete_documents([doc_2.id])
        assert document_store_1.count_documents() == document_store_2.count_documents() == 1

        document_store_2.delete_documents([doc_1.id])
        assert document_store_1.count_documents() == document_store_2.count_documents() == 0
|