mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-07 20:46:31 +00:00
* importing files from experimental * linting + tests * fixing integrations tests * adding release notes * fixing imports * adding query component * adding docs to docusaurus * Update docs/pydoc/config_docusaurus/query_api.yml Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com> * Update haystack/components/query/query_expander.py Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com> * Update releasenotes/notes/adding-QueryExpander-MultiQueryRetriever-88c4847894ea1fd0.yaml Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com> * fixing code examples * adding extra unit tests to assert deduplication is working * fixing and increasing QueryExpander tets --------- Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
290 lines
13 KiB
Python
290 lines
13 KiB
Python
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import os
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from haystack import Document, Pipeline
|
|
from haystack.components.embedders import SentenceTransformersTextEmbedder
|
|
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
|
|
from haystack.components.generators.chat import OpenAIChatGenerator
|
|
from haystack.components.query import QueryExpander
|
|
from haystack.components.retrievers import InMemoryEmbeddingRetriever, MultiQueryEmbeddingRetriever
|
|
from haystack.components.writers import DocumentWriter
|
|
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
|
from haystack.document_stores.types import DuplicatePolicy
|
|
|
|
|
|
class TestMultiQueryEmbeddingRetriever:
    """
    Unit and integration tests for MultiQueryEmbeddingRetriever.

    Unit tests mock the text embedder / inner retriever so no model download or API key
    is needed. Tests marked ``integration`` use real SentenceTransformers models, and the
    pipeline test additionally requires an ``OPENAI_API_KEY``.
    """

    @pytest.fixture
    def sample_documents(self):
        """A small corpus of energy-related documents with a ``category`` meta field for filter tests."""
        return [
            Document(
                content="Renewable energy is energy that is collected from renewable resources.",
                meta={"category": None},
            ),
            Document(
                content="Solar energy is a type of green energy that is harnessed from the sun.",
                meta={"category": "solar"},
            ),
            Document(
                content="Wind energy is another type of green energy that is generated by wind turbines",
                meta={"category": "wind"},
            ),
            Document(
                content="Hydropower is a form of renewable energy using the flow of water to generate electricity.",
                meta={"category": "hydro"},
            ),
            Document(
                content="Geothermal energy is heat that comes from the sub-surface of the earth.",
                meta={"category": "geo"},
            ),
            Document(
                content="Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources.",
                meta={"category": "fossil"},
            ),
            Document(
                content="Nuclear energy is produced through nuclear reactions, typically using uranium or "
                "plutonium as fuel.",
                meta={"category": "nuclear"},
            ),
        ]

    @pytest.fixture
    def document_store_with_embeddings(self, sample_documents):
        """Create a document store populated with embedded documents."""
        document_store = InMemoryDocumentStore()
        doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
        doc_embedder.warm_up()
        doc_writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP)

        embedded_docs = doc_embedder.run(sample_documents)["documents"]
        doc_writer.run(documents=embedded_docs)
        return document_store

    @pytest.fixture
    def mock_query_embedder(self):
        """
        A SentenceTransformersTextEmbedder whose backend is patched out.

        The embedding backend factory is mocked so no model is loaded, and ``embedder.run``
        is replaced by a plain function that always returns a constant 384-dim embedding.
        """
        with patch(
            "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
        ) as mock_text_embedder:
            mock_model = MagicMock()
            mock_text_embedder.return_value = mock_model

            def mock_encode(
                texts, batch_size=None, show_progress_bar=None, normalize_embeddings=None, precision=None, **kwargs
            ):  # noqa: E501
                # One constant unit-vector embedding per input text.
                return [np.ones(384).tolist() for _ in texts]

            mock_model.encode = mock_encode
            embedder = SentenceTransformersTextEmbedder(model="mock-model", progress_bar=False)

            def mock_run(text):
                # Bypass the component entirely: fixed embedding regardless of the query.
                embedding = np.ones(384).tolist()
                return {"embedding": embedding}

            # NOTE: ``run`` is now a plain function, not a MagicMock — tests must not
            # configure it via ``.return_value``.
            embedder.run = mock_run
            embedder.warm_up()
            return embedder

    def test_init_with_default_parameters(self, mock_query_embedder):
        """Default construction keeps the given components and uses max_workers=3."""
        embedding_retriever = InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore())
        query_embedder = mock_query_embedder

        retriever = MultiQueryEmbeddingRetriever(retriever=embedding_retriever, query_embedder=query_embedder)

        assert retriever.retriever == embedding_retriever
        assert retriever.query_embedder == mock_query_embedder
        assert retriever.max_workers == 3

    def test_init_with_custom_parameters(self, mock_query_embedder):
        """A custom ``max_workers`` value is stored as-is."""
        embedding_retriever = InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore())
        query_embedder = mock_query_embedder
        retriever = MultiQueryEmbeddingRetriever(
            retriever=embedding_retriever, query_embedder=query_embedder, max_workers=2
        )

        assert retriever.retriever == embedding_retriever
        assert retriever.query_embedder == mock_query_embedder
        assert retriever.max_workers == 2

    def test_run_with_empty_queries(self, mock_query_embedder):
        """Running with no queries yields an empty document list, not an error."""
        multi_retriever = MultiQueryEmbeddingRetriever(
            retriever=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()),
            query_embedder=mock_query_embedder,
        )

        result = multi_retriever.run(queries=[])

        assert "documents" in result
        assert result["documents"] == []

    def test_run_with_empty_results(self, mock_query_embedder):
        """A query against an empty document store returns no documents."""
        # The fixture already stubs ``mock_query_embedder.run`` with a plain function
        # returning a fixed embedding, so no extra mock configuration is needed here.
        # (Setting ``run.return_value`` would be dead code: ``run`` is not a MagicMock.)
        multi_retriever = MultiQueryEmbeddingRetriever(
            retriever=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()),
            query_embedder=mock_query_embedder,
        )
        result = multi_retriever.run(queries=["query"])
        assert "documents" in result
        assert result["documents"] == []

    def test_to_dict(self):
        """Serialization includes type, nested component configs, and max_workers."""
        multi_retriever = MultiQueryEmbeddingRetriever(
            retriever=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()),
            query_embedder=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
            max_workers=2,
        )

        result = multi_retriever.to_dict()

        assert "type" in result
        assert "init_parameters" in result
        assert result["init_parameters"]["max_workers"] == 2
        assert "retriever" in result["init_parameters"]
        assert "query_embedder" in result["init_parameters"]

    def test_from_dict(self):
        """Deserialization from a full config dict reconstructs the component."""
        data = {
            "type": "haystack.components.retrievers.multi_query_embedding_retriever.MultiQueryEmbeddingRetriever",  # noqa: E501
            "init_parameters": {
                "retriever": {
                    "type": "haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever",
                    "init_parameters": {
                        "document_store": {
                            "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
                            "init_parameters": {
                                "bm25_tokenization_regex": "(?u)\\b\\w\\w+\\b",
                                "bm25_algorithm": "BM25L",
                                "bm25_parameters": {},
                                "embedding_similarity_function": "dot_product",
                                "index": "4bb5369d-779f-487b-9c16-3c40f503438b",
                                # 'return_embedding': True # ToDo: investigate why this fails
                            },
                        },
                        "filters": None,
                        "top_k": 10,
                        "scale_score": False,
                        "return_embedding": False,
                        "filter_policy": "replace",
                    },
                },
                "query_embedder": {
                    "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",  # noqa: E501
                    "init_parameters": {
                        "model": "sentence-transformers/all-MiniLM-L6-v2",
                        "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
                        "prefix": "",
                        "suffix": "",
                        "batch_size": 32,
                        "progress_bar": True,
                        "normalize_embeddings": False,
                        "trust_remote_code": False,
                        "local_files_only": False,
                        "truncate_dim": None,
                        "model_kwargs": None,
                        "tokenizer_kwargs": None,
                        "config_kwargs": None,
                        "precision": "float32",
                        "encode_kwargs": None,
                        "backend": "torch",
                    },
                },
                "max_workers": 2,
            },
        }

        result = MultiQueryEmbeddingRetriever.from_dict(data)

        assert isinstance(result, MultiQueryEmbeddingRetriever)
        assert result.max_workers == 2

    def test_deduplication_with_overlapping_results(self, mock_query_embedder):
        """Documents with identical content retrieved for different queries are deduplicated."""
        doc1 = Document(content="Solar energy is renewable", id="doc1")
        doc1.score = 0.9
        doc2 = Document(content="Wind energy is clean", id="doc2")
        doc2.score = 0.8
        # same content as doc1 w/ different score
        doc3 = Document(content="Solar energy is renewable", id="doc3")
        doc3.score = 0.7

        # mocked retriever: first call returns (doc1, doc2), later calls return (doc3, doc2)
        mock_retriever = MagicMock()
        call_count = 0

        def mock_retriever_run(query_embedding, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return {"documents": [doc1, doc2]}
            else:
                return {"documents": [doc3, doc2]}

        mock_retriever.run = mock_retriever_run
        multi_retriever = MultiQueryEmbeddingRetriever(
            retriever=mock_retriever, query_embedder=mock_query_embedder, max_workers=1
        )
        result = multi_retriever.run(queries=["query1", "query2"])

        assert "documents" in result
        assert len(result["documents"]) == 2  # Only 2 unique documents (doc1/doc3 and doc2)

        contents = [doc.content for doc in result["documents"]]
        assert contents.count("Solar energy is renewable") == 1
        assert contents.count("Wind energy is clean") == 1

    @pytest.mark.integration
    def test_run_with_filters(self, document_store_with_embeddings):
        """Filters passed as retriever kwargs are forwarded to the inner retriever."""
        in_memory_retriever = InMemoryEmbeddingRetriever(document_store=document_store_with_embeddings)
        query_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
        multi_retriever = MultiQueryEmbeddingRetriever(retriever=in_memory_retriever, query_embedder=query_embedder)
        multi_retriever.warm_up()
        kwargs = {"filters": {"field": "category", "operator": "==", "value": "solar"}}
        result = multi_retriever.run(["energy"], kwargs)
        assert "documents" in result
        assert all(doc.meta.get("category") == "solar" for doc in result["documents"])

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    @pytest.mark.integration
    def test_pipeline_integration(self, document_store_with_embeddings):
        """End-to-end pipeline: QueryExpander feeds expanded queries into the retriever."""
        expander = QueryExpander(
            chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"), n_expansions=3, include_original_query=True
        )
        in_memory_retriever = InMemoryEmbeddingRetriever(document_store=document_store_with_embeddings)
        query_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
        multiquery_retriever = MultiQueryEmbeddingRetriever(
            retriever=in_memory_retriever, query_embedder=query_embedder, max_workers=3
        )

        pipeline = Pipeline()
        pipeline.add_component("query_expander", expander)
        pipeline.add_component("multiquery_retriever", multiquery_retriever)
        pipeline.connect("query_expander.queries", "multiquery_retriever.queries")

        data = {
            "query_expander": {"query": "green energy sources"},
            "multiquery_retriever": {"retriever_kwargs": {"top_k": 3}},
        }
        results = pipeline.run(data=data, include_outputs_from={"query_expander", "multiquery_retriever"})

        assert "multiquery_retriever" in results
        assert "documents" in results["multiquery_retriever"]
        assert len(results["multiquery_retriever"]["documents"]) > 0
        assert "query_expander" in results
        assert "queries" in results["query_expander"]
        # 3 expansions + the original query
        assert len(results["query_expander"]["queries"]) == 4

        # assert that documents are sorted by score (highest first)
        scores = [doc.score for doc in results["multiquery_retriever"]["documents"] if doc.score is not None]
        assert scores == sorted(scores, reverse=True)

        # assert there are not duplicates
        contents = [doc.content for doc in results["multiquery_retriever"]["documents"]]
        assert len(contents) == len(set(contents))