haystack/test/components/query/test_query_expander.py
David S. Batista 763174ede6
feat: adding QueryExpander, MultiQueryEmbeddingRetriever and MultiQueryTextRetriever (#10126)
* importing files from experimental

* linting + tests

* fixing integrations tests

* adding release notes

* fixing imports

* adding query component

* adding docs to docusaurus

* Update docs/pydoc/config_docusaurus/query_api.yml

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>

* Update haystack/components/query/query_expander.py

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>

* Update releasenotes/notes/adding-QueryExpander-MultiQueryRetriever-88c4847894ea1fd0.yaml

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>

* fixing code examples

* adding extra unit tests to assert deduplication is working

* fixing and increasing QueryExpander tests

---------

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
2025-11-25 10:22:42 +01:00

401 lines
16 KiB
Python

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import logging
import os
from unittest.mock import Mock
import pytest
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.query.query_expander import DEFAULT_PROMPT_TEMPLATE, QueryExpander
from haystack.dataclasses.chat_message import ChatMessage
@pytest.fixture
def mock_chat_generator():
    """Return a ``Mock`` constrained to the ``OpenAIChatGenerator`` interface.

    ``spec=`` makes attribute access fail for names the real class does not have,
    so tests cannot accidentally rely on non-existent generator attributes.
    Individual tests configure ``.run.return_value`` / ``.run.side_effect`` themselves.
    """
    return Mock(spec=OpenAIChatGenerator)
@pytest.fixture
def mock_chat_generator_with_warm_up():
    """Return an ``OpenAIChatGenerator``-spec'd ``Mock`` that also exposes a no-op ``warm_up``.

    The attribute is assigned explicitly rather than relying on the spec.
    NOTE(review): presumably so ``QueryExpander.warm_up`` finds a callable ``warm_up``
    on the generator — confirm against the component implementation.
    """

    def _noop_warm_up():
        return None

    generator = Mock(spec=OpenAIChatGenerator)
    generator.warm_up = _noop_warm_up
    return generator
class TestQueryExpander:
    """Unit tests for ``QueryExpander``.

    Tests that construct the default ``OpenAIChatGenerator`` set a fake ``OPENAI_API_KEY``
    via ``monkeypatch``; the others inject a ``Mock`` generator through fixtures.
    """

    def test_init_default_generator(self, monkeypatch):
        """Default construction builds an OpenAIChatGenerator and the expected defaults."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        assert expander.n_expansions == 4
        assert expander.include_original_query is True
        assert isinstance(expander.chat_generator, OpenAIChatGenerator)
        assert expander.chat_generator.model == "gpt-4.1-mini"
        assert expander._prompt_builder is not None

    def test_init_custom_generator(self, mock_chat_generator):
        """A user-supplied generator and n_expansions are stored as given."""
        expander = QueryExpander(chat_generator=mock_chat_generator, n_expansions=3)
        assert expander.n_expansions == 3
        assert expander.chat_generator is mock_chat_generator

    def test_run_warm_up(self, mock_chat_generator_with_warm_up):
        """After warm_up(), the component is marked warm and run() returns parsed expansions.

        The mock returns the real chat-generator output shape (``{"replies": [ChatMessage]}``).
        A bare ``{"queries": ...}`` dict would not be parsed as expansions, so the test
        could only pass through the fallback that returns just the original query.
        """
        mock_chat_generator_with_warm_up.run.return_value = {
            "replies": [ChatMessage.from_assistant('{"queries": ["alt1"]}')]
        }
        expander = QueryExpander(chat_generator=mock_chat_generator_with_warm_up)
        expander.warm_up()
        assert expander._is_warmed_up is True
        # include_original_query defaults to True: the original is appended after the expansions.
        assert expander.run("test query") == {"queries": ["alt1", "test query"]}

    def test_warm_up(self, mock_chat_generator):
        """warm_up() flips the warmed-up flag."""
        expander = QueryExpander(chat_generator=mock_chat_generator)
        expander.warm_up()
        assert expander._is_warmed_up is True

    def test_init_negative_expansions_raises_error(self):
        """A negative n_expansions is rejected at construction time."""
        with pytest.raises(ValueError, match="n_expansions must be positive"):
            QueryExpander(n_expansions=-1)

    def test_init_zero_expansions_raises_error(self):
        """Zero n_expansions is rejected at construction time."""
        with pytest.raises(ValueError, match="n_expansions must be positive"):
            QueryExpander(n_expansions=0)

    def test_init_custom_prompt_template(self, monkeypatch):
        """A custom prompt template is stored verbatim."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        custom_template = "Custom template: {{ query }} with {{ n_expansions }} expansions"
        expander = QueryExpander(prompt_template=custom_template)
        assert expander.prompt_template == custom_template

    def test_run_negative_expansions_raises_error(self, monkeypatch):
        """A negative runtime n_expansions override is rejected."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()
        with pytest.raises(ValueError, match="n_expansions must be positive"):
            expander.run("test query", n_expansions=-1)

    def test_run_zero_expansions_raises_error(self, monkeypatch):
        """A zero runtime n_expansions override is rejected."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander(n_expansions=4)
        expander.warm_up()
        with pytest.raises(ValueError, match="n_expansions must be positive"):
            expander.run("test query", n_expansions=0)

    def test_run_with_runtime_n_expansions_override(self, mock_chat_generator):
        """n_expansions passed to run() takes effect in the prompt and the result."""
        mock_chat_generator.run.return_value = {
            "replies": [ChatMessage.from_assistant('{"queries": ["alt1", "alt2"]}')]
        }
        expander = QueryExpander(chat_generator=mock_chat_generator, n_expansions=4, include_original_query=False)
        expander.warm_up()
        result = expander.run("test query", n_expansions=2)
        # The rendered prompt (first message sent to the generator) should mention the runtime value.
        prompt_text = mock_chat_generator.run.call_args[1]["messages"][0].text
        assert "2" in prompt_text
        assert len(result["queries"]) == 2
        assert result["queries"] == ["alt1", "alt2"]

    def test_run_successful_expansion(self, mock_chat_generator):
        """Expansions come first, followed by the original query."""
        mock_chat_generator.run.return_value = {
            "replies": [
                ChatMessage.from_assistant(
                    '{"queries": ["alternative query 1", "alternative query 2", "alternative query 3"]}'
                )
            ]
        }
        expander = QueryExpander(chat_generator=mock_chat_generator, n_expansions=3)
        expander.warm_up()
        result = expander.run("original query")
        assert result["queries"] == [
            "alternative query 1",
            "alternative query 2",
            "alternative query 3",
            "original query",
        ]
        mock_chat_generator.run.assert_called_once()

    def test_run_without_including_original(self, mock_chat_generator):
        """With include_original_query=False only the expansions are returned."""
        mock_chat_generator.run.return_value = {
            "replies": [ChatMessage.from_assistant('{"queries": ["alt1", "alt2"]}')]
        }
        expander = QueryExpander(chat_generator=mock_chat_generator, include_original_query=False)
        expander.warm_up()
        result = expander.run("original")
        assert result["queries"] == ["alt1", "alt2"]

    def test_run_empty_query(self, monkeypatch):
        """An empty query yields just the (empty) original query."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()
        result = expander.run("")
        assert result["queries"] == [""]

    def test_run_empty_query_no_original(self, monkeypatch):
        """A blank query with include_original_query=False yields no queries."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander(include_original_query=False)
        expander.warm_up()
        result = expander.run(" ")
        assert result["queries"] == []

    def test_run_whitespace_only_query(self, monkeypatch, caplog):
        """A whitespace-only query is treated as empty and logs a warning."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()
        with caplog.at_level(logging.WARNING):
            result = expander.run("\t\n \r")
            assert result["queries"] == ["\t\n \r"]
            assert "Empty query provided" in caplog.text

    def test_run_generator_no_replies(self, mock_chat_generator):
        """No replies from the generator falls back to the original query."""
        mock_chat_generator.run.return_value = {"replies": []}
        expander = QueryExpander(chat_generator=mock_chat_generator)
        expander.warm_up()
        result = expander.run("test query")
        assert result["queries"] == ["test query"]

    def test_run_generator_exception(self, mock_chat_generator):
        """A generator exception falls back to the original query."""
        mock_chat_generator.run.side_effect = Exception("Generator error")
        expander = QueryExpander(chat_generator=mock_chat_generator)
        expander.warm_up()
        result = expander.run("test query")
        assert result["queries"] == ["test query"]

    def test_run_invalid_json_response(self, mock_chat_generator):
        """A non-JSON reply falls back to the original query."""
        mock_chat_generator.run.return_value = {"replies": [ChatMessage.from_assistant("invalid json response")]}
        expander = QueryExpander(chat_generator=mock_chat_generator)
        expander.warm_up()
        result = expander.run("test query")
        assert result["queries"] == ["test query"]

    def test_parse_expanded_queries_valid_json(self, monkeypatch):
        """Well-formed JSON with a "queries" list is parsed in order."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()
        queries = expander._parse_expanded_queries('{"queries": ["query1", "query2", "query3"]}')
        assert queries == ["query1", "query2", "query3"]

    def test_parse_expanded_queries_invalid_json(self, monkeypatch):
        """Non-JSON input parses to an empty list."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()
        queries = expander._parse_expanded_queries("not json")
        assert queries == []

    def test_parse_expanded_queries_empty_string(self, monkeypatch):
        """An empty string parses to an empty list."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()
        queries = expander._parse_expanded_queries("")
        assert queries == []

    def test_parse_expanded_queries_non_list_json(self, monkeypatch):
        """JSON without a "queries" list parses to an empty list."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()
        queries = expander._parse_expanded_queries('{"not": "a list"}')
        assert queries == []

    def test_parse_expanded_queries_mixed_types(self, monkeypatch):
        """Non-string and empty entries are dropped from the parsed list."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()
        queries = expander._parse_expanded_queries('{"queries": ["valid query", 123, "", "another valid"]}')
        assert queries == ["valid query", "another valid"]

    def test_run_query_deduplication(self, mock_chat_generator):
        """The original query is not duplicated when the LLM echoes it back."""
        mock_chat_generator.run.return_value = {
            "replies": [ChatMessage.from_assistant('{"queries": ["original query", "alt1", "alt2"]}')]
        }
        expander = QueryExpander(chat_generator=mock_chat_generator, include_original_query=True)
        expander.warm_up()
        result = expander.run("original query")
        assert result["queries"] == ["original query", "alt1", "alt2"]
        assert len(result["queries"]) == 3

    def test_run_truncates_excess_queries(self, mock_chat_generator, caplog):
        """Extra generated queries beyond n_expansions are truncated with a warning."""
        mock_chat_generator.run.return_value = {
            "replies": [ChatMessage.from_assistant('{"queries": ["q1", "q2", "q3", "q4", "q5"]}')]
        }
        expander = QueryExpander(chat_generator=mock_chat_generator, n_expansions=3, include_original_query=False)
        expander.warm_up()
        with caplog.at_level(logging.WARNING):
            result = expander.run("test query")
            assert len(result["queries"]) == 3
            assert result["queries"] == ["q1", "q2", "q3"]
            assert "Generated 5 queries but only 3 were requested" in caplog.text
            assert "Truncating" in caplog.text

    def test_run_with_custom_template(self, mock_chat_generator):
        """A custom template is rendered with query and n_expansions and sent to the generator."""
        custom_template = """
Create {{ n_expansions }} alternative search queries for: {{ query }}
Return as JSON: {"queries": ["query1", "query2"]}
"""
        mock_chat_generator.run.return_value = {
            "replies": [ChatMessage.from_assistant('{"queries": ["custom alt 1", "custom alt 2"]}')]
        }
        expander = QueryExpander(
            chat_generator=mock_chat_generator,
            prompt_template=custom_template,
            n_expansions=2,
            include_original_query=False,
        )
        expander.warm_up()
        result = expander.run("test query")
        assert result["queries"] == ["custom alt 1", "custom alt 2"]
        mock_chat_generator.run.assert_called_once()
        # First message passed to the generator carries the rendered prompt.
        prompt_text = mock_chat_generator.run.call_args[1]["messages"][0].text
        assert "Create 2 alternative search queries for: test query" in prompt_text
        assert "Return as JSON" in prompt_text

    def test_component_output_types(self, mock_chat_generator, monkeypatch):
        """run() returns a dict whose "queries" value is a list of strings."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()
        mock_chat_generator.run.return_value = {
            "replies": [ChatMessage.from_assistant('{"queries": ["test1", "test2"]}')]
        }
        # Swap in the mock after warm-up so run() never reaches the real OpenAI client.
        expander.chat_generator = mock_chat_generator
        result = expander.run("test")
        assert "queries" in result
        assert isinstance(result["queries"], list)
        assert all(isinstance(q, str) for q in result["queries"])

    @pytest.mark.parametrize("variable", ["query", "n_expansions"])
    def test_prompt_template_missing_variable(self, caplog, variable, monkeypatch):
        """A template missing a required variable triggers a warning at construction."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        if variable == "query":
            template_missing_variable = "Generate {{ n_expansions }} expansions"
        else:
            template_missing_variable = "Generate expansions for {{ query }}"
        with caplog.at_level(logging.WARNING):
            QueryExpander(prompt_template=template_missing_variable)
            assert f"The prompt template does not contain the '{variable}' variable" in caplog.text
            assert "This may cause issues during execution" in caplog.text

    def test_to_dict(self, monkeypatch):
        """Serialization includes the nested generator and all init parameters."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        generator = OpenAIChatGenerator(model="gpt-4.1-mini")
        expander = QueryExpander(chat_generator=generator, n_expansions=2, include_original_query=False)
        serialized_query_expander = expander.to_dict()
        assert serialized_query_expander == {
            "type": "haystack.components.query.query_expander.QueryExpander",
            "init_parameters": {
                "chat_generator": {
                    "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                    "init_parameters": {
                        "model": "gpt-4.1-mini",
                        "streaming_callback": None,
                        "api_base_url": None,
                        "organization": None,
                        "generation_kwargs": {},
                        "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                        "timeout": None,
                        "max_retries": None,
                        "tools": None,
                        "tools_strict": False,
                        "http_client_kwargs": None,
                    },
                },
                "prompt_template": DEFAULT_PROMPT_TEMPLATE,
                "n_expansions": 2,
                "include_original_query": False,
            },
        }

    def test_from_dict(self, monkeypatch):
        """Deserialization restores the generator and all init parameters."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        data = {
            "type": "haystack.components.query.query_expander.QueryExpander",
            "init_parameters": {
                "chat_generator": {
                    "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                    "init_parameters": {
                        "model": "gpt-4.1-mini",
                        "streaming_callback": None,
                        "api_base_url": None,
                        "organization": None,
                        "generation_kwargs": {},
                        "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                        "timeout": None,
                        "max_retries": None,
                        "tools": None,
                        "tools_strict": False,
                        "http_client_kwargs": None,
                    },
                },
                "prompt_template": DEFAULT_PROMPT_TEMPLATE,
                "n_expansions": 2,
                "include_original_query": False,
            },
        }
        expander = QueryExpander.from_dict(data)
        assert expander.n_expansions == 2
        # Identity check, consistent with the `is True` assertions above (ruff E712).
        assert expander.include_original_query is False
        assert expander.prompt_template == DEFAULT_PROMPT_TEMPLATE
        assert isinstance(expander.chat_generator, OpenAIChatGenerator)
        assert expander.chat_generator.model == "gpt-4.1-mini"
@pytest.mark.integration
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY"),
    reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
class TestQueryExpanderIntegration:
    """Live tests against the default OpenAI-backed QueryExpander.

    The skip condition is declared once on the class; pytest applies class-level
    marks to every test method, so each test is skipped when OPENAI_API_KEY is unset.
    """

    def test_query_expansion(self):
        """Three expansions plus the original query, all non-blank."""
        original = "renewable energy sources"
        expander = QueryExpander(n_expansions=3)
        expander.warm_up()
        queries = expander.run(original)["queries"]
        assert len(queries) == 4
        assert all(q.strip() for q in queries)
        assert original in queries

    def test_different_domains(self):
        """Expansion works across topics and excludes the original when asked to."""
        expander = QueryExpander(n_expansions=2, include_original_query=False)
        expander.warm_up()
        for topic in ["machine learning algorithms", "climate change effects", "quantum computing applications"]:
            expansions = expander.run(topic)["queries"]
            # Exactly the two requested expansions, with the original left out.
            assert len(expansions) == 2
            assert topic not in expansions