haystack/test/components/retrievers/test_multi_query_embedding_retriever.py

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import os
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from haystack import Document, Pipeline
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.query import QueryExpander
from haystack.components.retrievers import InMemoryEmbeddingRetriever, MultiQueryEmbeddingRetriever
from haystack.components.writers import DocumentWriter
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy


class TestMultiQueryEmbeddingRetriever:
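    """Tests for MultiQueryEmbeddingRetriever: init, (de)serialization, deduplication, filtering, and pipeline use."""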
@pytest.fixture
def sample_documents(self):
return [
Document(
content="Renewable energy is energy that is collected from renewable resources.",
meta={"category": None},
),
Document(
content="Solar energy is a type of green energy that is harnessed from the sun.",
meta={"category": "solar"},
),
Document(
content="Wind energy is another type of green energy that is generated by wind turbines",
meta={"category": "wind"},
),
Document(
content="Hydropower is a form of renewable energy using the flow of water to generate electricity.",
meta={"category": "hydro"},
),
Document(
content="Geothermal energy is heat that comes from the sub-surface of the earth.",
meta={"category": "geo"},
),
Document(
content="Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources.",
meta={"category": "fossil"},
),
Document(
content="Nuclear energy is produced through nuclear reactions, typically using uranium or "
"plutonium as fuel.",
meta={"category": "nuclear"},
),
        ]

    @pytest.fixture
def document_store_with_embeddings(self, sample_documents):
"""Create a document store populated with embedded documents."""
document_store = InMemoryDocumentStore()
doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
doc_embedder.warm_up()
doc_writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP)
embedded_docs = doc_embedder.run(sample_documents)["documents"]
doc_writer.run(documents=embedded_docs)
        return document_store

    @pytest.fixture
def mock_query_embedder(self):
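        """SentenceTransformersTextEmbedder with a mocked backend, so no model is downloaded or loaded."""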
with patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
) as mock_text_embedder:
mock_model = MagicMock()
mock_text_embedder.return_value = mock_model
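            # Embedding backend stub: every text maps to a constant all-ones vector of 384 dims,
            # the dimensionality used by all-MiniLM-L6-v2 elsewhere in these tests.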
def mock_encode(
texts, batch_size=None, show_progress_bar=None, normalize_embeddings=None, precision=None, **kwargs
): # noqa E501
return [np.ones(384).tolist() for _ in texts]
mock_model.encode = mock_encode
embedder = SentenceTransformersTextEmbedder(model="mock-model", progress_bar=False)
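            # Also stub run() directly so each call returns a fixed 384-dim embedding without any tokenization.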
def mock_run(text):
embedding = np.ones(384).tolist()
return {"embedding": embedding}
embedder.run = mock_run
embedder.warm_up()
return embedder
def test_init_with_default_parameters(self, mock_query_embedder):
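        """Default init stores the wrapped retriever and embedder and uses max_workers=3."""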
embedding_retriever = InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore())
query_embedder = mock_query_embedder
retriever = MultiQueryEmbeddingRetriever(retriever=embedding_retriever, query_embedder=query_embedder)
assert retriever.retriever == embedding_retriever
assert retriever.query_embedder == mock_query_embedder
        assert retriever.max_workers == 3

    def test_init_with_custom_parameters(self, mock_query_embedder):
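        """A custom max_workers value is stored alongside the wrapped retriever and embedder."""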
embedding_retriever = InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore())
query_embedder = mock_query_embedder
retriever = MultiQueryEmbeddingRetriever(
retriever=embedding_retriever, query_embedder=query_embedder, max_workers=2
)
assert retriever.retriever == embedding_retriever
assert retriever.query_embedder == mock_query_embedder
        assert retriever.max_workers == 2

    def test_run_with_empty_queries(self, mock_query_embedder):
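        """An empty query list yields an empty document list."""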
multi_retriever = MultiQueryEmbeddingRetriever(
retriever=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()),
query_embedder=mock_query_embedder,
)
result = multi_retriever.run(queries=[])
assert "documents" in result
        assert result["documents"] == []

    def test_run_with_empty_results(self, mock_query_embedder):
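        """A query against an empty document store yields an empty document list."""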
mock_query_embedder.run.return_value = {"embedding": [0.1, 0.2, 0.3]}
multi_retriever = MultiQueryEmbeddingRetriever(
retriever=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()),
query_embedder=mock_query_embedder,
)
result = multi_retriever.run(queries=["query"])
assert "documents" in result
        assert result["documents"] == []

    def test_to_dict(self):
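        """to_dict serializes max_workers plus the nested retriever and query embedder."""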
multi_retriever = MultiQueryEmbeddingRetriever(
retriever=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()),
query_embedder=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
max_workers=2,
)
result = multi_retriever.to_dict()
assert "type" in result
assert "init_parameters" in result
assert result["init_parameters"]["max_workers"] == 2
assert "retriever" in result["init_parameters"]
assert "query_embedder" in result["init_parameters"]
def test_from_dict(self):
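        """from_dict rebuilds the component, including its nested retriever and query embedder."""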
data = {
"type": "haystack.components.retrievers.multi_query_embedding_retriever.MultiQueryEmbeddingRetriever", # noqa E501
"init_parameters": {
"retriever": {
"type": "haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever",
"init_parameters": {
"document_store": {
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
"init_parameters": {
"bm25_tokenization_regex": "(?u)\\b\\w\\w+\\b",
"bm25_algorithm": "BM25L",
"bm25_parameters": {},
"embedding_similarity_function": "dot_product",
"index": "4bb5369d-779f-487b-9c16-3c40f503438b",
# 'return_embedding': True # ToDo: investigate why this fails
},
},
"filters": None,
"top_k": 10,
"scale_score": False,
"return_embedding": False,
"filter_policy": "replace",
},
},
"query_embedder": {
"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder", # noqa E501
"init_parameters": {
"model": "sentence-transformers/all-MiniLM-L6-v2",
"token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
"prefix": "",
"suffix": "",
"batch_size": 32,
"progress_bar": True,
"normalize_embeddings": False,
"trust_remote_code": False,
"local_files_only": False,
"truncate_dim": None,
"model_kwargs": None,
"tokenizer_kwargs": None,
"config_kwargs": None,
"precision": "float32",
"encode_kwargs": None,
"backend": "torch",
},
},
"max_workers": 2,
},
}
result = MultiQueryEmbeddingRetriever.from_dict(data)
assert isinstance(result, MultiQueryEmbeddingRetriever)
        assert result.max_workers == 2

    def test_deduplication_with_overlapping_results(self, mock_query_embedder):
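        """Documents with identical content returned by different queries are merged into a single result."""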
doc1 = Document(content="Solar energy is renewable", id="doc1")
doc1.score = 0.9
doc2 = Document(content="Wind energy is clean", id="doc2")
doc2.score = 0.8
# same content as doc1 w/ different score
doc3 = Document(content="Solar energy is renewable", id="doc3")
doc3.score = 0.7
# mocked retriever
mock_retriever = MagicMock()
call_count = 0
def mock_retriever_run(query_embedding, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return {"documents": [doc1, doc2]}
else:
return {"documents": [doc3, doc2]}
mock_retriever.run = mock_retriever_run
multi_retriever = MultiQueryEmbeddingRetriever(
retriever=mock_retriever, query_embedder=mock_query_embedder, max_workers=1
)
result = multi_retriever.run(queries=["query1", "query2"])
assert "documents" in result
assert len(result["documents"]) == 2 # Only 2 unique documents (doc1/doc3 and doc2)
contents = [doc.content for doc in result["documents"]]
assert contents.count("Solar energy is renewable") == 1
        assert contents.count("Wind energy is clean") == 1

    @pytest.mark.integration
def test_run_with_filters(self, document_store_with_embeddings):
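        """Filters passed through the retriever kwargs restrict results to matching documents (integration)."""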
in_memory_retriever = InMemoryEmbeddingRetriever(document_store=document_store_with_embeddings)
query_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
multi_retriever = MultiQueryEmbeddingRetriever(retriever=in_memory_retriever, query_embedder=query_embedder)
multi_retriever.warm_up()
kwargs = {"filters": {"field": "category", "operator": "==", "value": "solar"}}
result = multi_retriever.run(["energy"], kwargs)
assert "documents" in result
        assert all(doc.meta.get("category") == "solar" for doc in result["documents"])

    @pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_pipeline_integration(self, document_store_with_embeddings):
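        """Full pipeline run: QueryExpander expands the query, MultiQueryEmbeddingRetriever returns deduplicated,
        score-sorted documents (requires OPENAI_API_KEY)."""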
expander = QueryExpander(
chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"), n_expansions=3, include_original_query=True
)
in_memory_retriever = InMemoryEmbeddingRetriever(document_store=document_store_with_embeddings)
query_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
multiquery_retriever = MultiQueryEmbeddingRetriever(
retriever=in_memory_retriever, query_embedder=query_embedder, max_workers=3
)
pipeline = Pipeline()
pipeline.add_component("query_expander", expander)
pipeline.add_component("multiquery_retriever", multiquery_retriever)
pipeline.connect("query_expander.queries", "multiquery_retriever.queries")
data = {
"query_expander": {"query": "green energy sources"},
"multiquery_retriever": {"retriever_kwargs": {"top_k": 3}},
}
results = pipeline.run(data=data, include_outputs_from={"query_expander", "multiquery_retriever"})
assert "multiquery_retriever" in results
assert "documents" in results["multiquery_retriever"]
assert len(results["multiquery_retriever"]["documents"]) > 0
assert "query_expander" in results
assert "queries" in results["query_expander"]
assert len(results["query_expander"]["queries"]) == 4
# assert that documents are sorted by score (highest first)
scores = [doc.score for doc in results["multiquery_retriever"]["documents"] if doc.score is not None]
assert scores == sorted(scores, reverse=True)
        # assert there are no duplicates in the merged results
contents = [doc.content for doc in results["multiquery_retriever"]["documents"]]
assert len(contents) == len(set(contents))