mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-31 09:10:15 +00:00
* importing files from experimental * linting + tests * fixing integrations tests * adding release notes * fixing imports * adding query component * adding docs to docusaurus * Update docs/pydoc/config_docusaurus/query_api.yml Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com> * Update haystack/components/query/query_expander.py Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com> * Update releasenotes/notes/adding-QueryExpander-MultiQueryRetriever-88c4847894ea1fd0.yaml Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com> * fixing code examples * adding extra unit tests to assert deduplication is working * fixing and increasing QueryExpander tets --------- Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
401 lines
16 KiB
Python
401 lines
16 KiB
Python
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import logging
|
|
import os
|
|
from unittest.mock import Mock
|
|
|
|
import pytest
|
|
|
|
from haystack.components.generators.chat.openai import OpenAIChatGenerator
|
|
from haystack.components.query.query_expander import DEFAULT_PROMPT_TEMPLATE, QueryExpander
|
|
from haystack.dataclasses.chat_message import ChatMessage
|
|
|
|
|
|
@pytest.fixture
def mock_chat_generator():
    """A bare Mock standing in for an OpenAIChatGenerator instance."""
    return Mock(spec=OpenAIChatGenerator)
|
|
|
|
|
|
@pytest.fixture
def mock_chat_generator_with_warm_up():
    """A Mock OpenAIChatGenerator whose warm_up() is a callable no-op."""
    generator = Mock(spec=OpenAIChatGenerator)
    generator.warm_up = lambda: None
    return generator
|
|
|
|
|
|
class TestQueryExpander:
    """Unit tests for QueryExpander: construction, validation, expansion, LLM-output
    parsing, deduplication/truncation, and (de)serialization via to_dict/from_dict."""

    def test_init_default_generator(self, monkeypatch):
        """Default construction builds an OpenAIChatGenerator with the documented defaults."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()

        assert expander.n_expansions == 4
        assert expander.include_original_query is True
        assert isinstance(expander.chat_generator, OpenAIChatGenerator)
        assert expander.chat_generator.model == "gpt-4.1-mini"
        assert expander._prompt_builder is not None

    def test_init_custom_generator(self, mock_chat_generator):
        """A user-supplied chat generator is stored as-is (no copy)."""
        expander = QueryExpander(chat_generator=mock_chat_generator, n_expansions=3)

        assert expander.n_expansions == 3
        assert expander.chat_generator is mock_chat_generator

    def test_run_warm_up(self, mock_chat_generator_with_warm_up):
        """run() works after warm_up(); a generator reply without 'replies' falls back
        to returning the original query."""
        expander = QueryExpander(chat_generator=mock_chat_generator_with_warm_up)
        # The mock's return value has no "replies" key, so the component cannot parse
        # expansions and falls back to the original query.
        mock_chat_generator_with_warm_up.run.return_value = {"queries": ["test query"]}

        expander.warm_up()

        assert expander._is_warmed_up is True
        # Fix: the original test invoked run() twice with identical arguments; a single
        # call is sufficient for the assertion.
        assert expander.run("test query") == {"queries": ["test query"]}

    def test_warm_up(self, mock_chat_generator):
        """warm_up() sets the internal warmed-up flag."""
        expander = QueryExpander(chat_generator=mock_chat_generator)
        expander.warm_up()
        assert expander._is_warmed_up is True

    def test_init_negative_expansions_raises_error(self):
        """n_expansions < 0 is rejected at construction time."""
        with pytest.raises(ValueError, match="n_expansions must be positive"):
            QueryExpander(n_expansions=-1)

    def test_init_zero_expansions_raises_error(self):
        """n_expansions == 0 is rejected at construction time."""
        with pytest.raises(ValueError, match="n_expansions must be positive"):
            QueryExpander(n_expansions=0)

    def test_init_custom_prompt_template(self, monkeypatch):
        """A custom prompt template is stored verbatim."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        custom_template = "Custom template: {{ query }} with {{ n_expansions }} expansions"
        expander = QueryExpander(prompt_template=custom_template)

        assert expander.prompt_template == custom_template

    def test_run_negative_expansions_raises_error(self, monkeypatch):
        """A negative runtime n_expansions override is rejected."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()
        with pytest.raises(ValueError, match="n_expansions must be positive"):
            expander.run("test query", n_expansions=-1)

    def test_run_zero_expansions_raises_error(self, monkeypatch):
        """A zero runtime n_expansions override is rejected."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander(n_expansions=4)
        expander.warm_up()
        with pytest.raises(ValueError, match="n_expansions must be positive"):
            expander.run("test query", n_expansions=0)

    def test_run_with_runtime_n_expansions_override(self, mock_chat_generator):
        """A runtime n_expansions overrides the value set at construction time."""
        mock_chat_generator.run.return_value = {
            "replies": [ChatMessage.from_assistant('{"queries": ["alt1", "alt2"]}')]
        }

        expander = QueryExpander(chat_generator=mock_chat_generator, n_expansions=4, include_original_query=False)
        expander.warm_up()
        result = expander.run("test query", n_expansions=2)

        # should request 2 expansions
        call_args = mock_chat_generator.run.call_args[1]["messages"][0].text
        assert "2" in call_args
        assert len(result["queries"]) == 2
        assert result["queries"] == ["alt1", "alt2"]

    def test_run_successful_expansion(self, mock_chat_generator):
        """Expanded queries are returned followed by the original query."""
        mock_chat_generator.run.return_value = {
            "replies": [
                ChatMessage.from_assistant(
                    '{"queries": ["alternative query 1", "alternative query 2", "alternative query 3"]}'
                )
            ]
        }

        expander = QueryExpander(chat_generator=mock_chat_generator, n_expansions=3)
        expander.warm_up()
        result = expander.run("original query")

        assert result["queries"] == [
            "alternative query 1",
            "alternative query 2",
            "alternative query 3",
            "original query",
        ]
        mock_chat_generator.run.assert_called_once()

    def test_run_without_including_original(self, mock_chat_generator):
        """include_original_query=False keeps the original query out of the output."""
        mock_chat_generator.run.return_value = {
            "replies": [ChatMessage.from_assistant('{"queries": ["alt1", "alt2"]}')]
        }

        expander = QueryExpander(chat_generator=mock_chat_generator, include_original_query=False)
        expander.warm_up()
        result = expander.run("original")

        assert result["queries"] == ["alt1", "alt2"]

    def test_run_empty_query(self, monkeypatch):
        """An empty query is returned unchanged."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()
        result = expander.run("")

        assert result["queries"] == [""]

    def test_run_empty_query_no_original(self, monkeypatch):
        """A blank query with include_original_query=False yields an empty result."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander(include_original_query=False)
        expander.warm_up()
        result = expander.run(" ")

        assert result["queries"] == []

    def test_run_whitespace_only_query(self, monkeypatch, caplog):
        """A whitespace-only query is returned as-is and a warning is logged."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()
        with caplog.at_level(logging.WARNING):
            result = expander.run("\t\n \r")
        assert result["queries"] == ["\t\n \r"]
        assert "Empty query provided" in caplog.text

    def test_run_generator_no_replies(self, mock_chat_generator):
        """An empty 'replies' list falls back to the original query."""
        mock_chat_generator.run.return_value = {"replies": []}
        expander = QueryExpander(chat_generator=mock_chat_generator)
        expander.warm_up()
        result = expander.run("test query")

        assert result["queries"] == ["test query"]

    def test_run_generator_exception(self, mock_chat_generator):
        """A generator failure is handled gracefully: the original query is returned."""
        mock_chat_generator.run.side_effect = Exception("Generator error")
        expander = QueryExpander(chat_generator=mock_chat_generator)
        expander.warm_up()
        result = expander.run("test query")
        assert result["queries"] == ["test query"]

    def test_run_invalid_json_response(self, mock_chat_generator):
        """A non-JSON LLM reply falls back to the original query."""
        mock_chat_generator.run.return_value = {"replies": [ChatMessage.from_assistant("invalid json response")]}
        expander = QueryExpander(chat_generator=mock_chat_generator)
        expander.warm_up()
        result = expander.run("test query")
        assert result["queries"] == ["test query"]

    def test_parse_expanded_queries_valid_json(self, monkeypatch):
        """A well-formed JSON payload is parsed into its query list."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()
        queries = expander._parse_expanded_queries('{"queries": ["query1", "query2", "query3"]}')
        assert queries == ["query1", "query2", "query3"]

    def test_parse_expanded_queries_invalid_json(self, monkeypatch):
        """Unparseable text yields an empty list."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()
        queries = expander._parse_expanded_queries("not json")
        assert queries == []

    def test_parse_expanded_queries_empty_string(self, monkeypatch):
        """An empty string yields an empty list."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()
        queries = expander._parse_expanded_queries("")
        assert queries == []

    def test_parse_expanded_queries_non_list_json(self, monkeypatch):
        """Valid JSON without a 'queries' list yields an empty list."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()
        queries = expander._parse_expanded_queries('{"not": "a list"}')
        assert queries == []

    def test_parse_expanded_queries_mixed_types(self, monkeypatch):
        """Non-string and empty entries are filtered out of the parsed list."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()
        queries = expander._parse_expanded_queries('{"queries": ["valid query", 123, "", "another valid"]}')
        assert queries == ["valid query", "another valid"]

    def test_run_query_deduplication(self, mock_chat_generator):
        """When the LLM echoes the original query, it is not duplicated in the output."""
        mock_chat_generator.run.return_value = {
            "replies": [ChatMessage.from_assistant('{"queries": ["original query", "alt1", "alt2"]}')]
        }
        expander = QueryExpander(chat_generator=mock_chat_generator, include_original_query=True)
        expander.warm_up()
        result = expander.run("original query")
        assert result["queries"] == ["original query", "alt1", "alt2"]
        assert len(result["queries"]) == 3

    def test_run_truncates_excess_queries(self, mock_chat_generator, caplog):
        """Surplus expansions beyond n_expansions are truncated with a warning."""
        mock_chat_generator.run.return_value = {
            "replies": [ChatMessage.from_assistant('{"queries": ["q1", "q2", "q3", "q4", "q5"]}')]
        }
        expander = QueryExpander(chat_generator=mock_chat_generator, n_expansions=3, include_original_query=False)
        expander.warm_up()

        with caplog.at_level(logging.WARNING):
            result = expander.run("test query")

        assert len(result["queries"]) == 3
        assert result["queries"] == ["q1", "q2", "q3"]
        assert "Generated 5 queries but only 3 were requested" in caplog.text
        assert "Truncating" in caplog.text

    def test_run_with_custom_template(self, mock_chat_generator):
        """A custom Jinja template is rendered into the message sent to the generator."""
        custom_template = """
        Create {{ n_expansions }} alternative search queries for: {{ query }}
        Return as JSON: {"queries": ["query1", "query2"]}
        """

        mock_chat_generator.run.return_value = {
            "replies": [ChatMessage.from_assistant('{"queries": ["custom alt 1", "custom alt 2"]}')]
        }

        expander = QueryExpander(
            chat_generator=mock_chat_generator,
            prompt_template=custom_template,
            n_expansions=2,
            include_original_query=False,
        )
        expander.warm_up()
        result = expander.run("test query")

        assert result["queries"] == ["custom alt 1", "custom alt 2"]

        mock_chat_generator.run.assert_called_once()
        call_args = mock_chat_generator.run.call_args[1]["messages"][0].text
        assert "Create 2 alternative search queries for: test query" in call_args
        assert "Return as JSON" in call_args

    def test_component_output_types(self, mock_chat_generator, monkeypatch):
        """run() always returns {"queries": list[str]}."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        expander = QueryExpander()
        expander.warm_up()

        mock_chat_generator.run.return_value = {
            "replies": [ChatMessage.from_assistant('{"queries": ["test1", "test2"]}')]
        }
        # Swap in the mock after warm-up so no real API call is made.
        expander.chat_generator = mock_chat_generator

        result = expander.run("test")
        assert "queries" in result
        assert isinstance(result["queries"], list)
        assert all(isinstance(q, str) for q in result["queries"])

    @pytest.mark.parametrize("variable", ["query", "n_expansions"])
    def test_prompt_template_missing_variable(self, caplog, variable, monkeypatch):
        """A template missing a required variable triggers a warning at construction."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        if variable == "query":
            template_missing_variable = "Generate {{ n_expansions }} expansions"
        else:
            template_missing_variable = "Generate expansions for {{ query }}"

        with caplog.at_level(logging.WARNING):
            QueryExpander(prompt_template=template_missing_variable)

        assert f"The prompt template does not contain the '{variable}' variable" in caplog.text
        assert "This may cause issues during execution" in caplog.text

    def test_to_dict(self, monkeypatch):
        """to_dict() serializes the component, including the nested chat generator."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")
        generator = OpenAIChatGenerator(model="gpt-4.1-mini")
        expander = QueryExpander(chat_generator=generator, n_expansions=2, include_original_query=False)

        serialized_query_expander = expander.to_dict()

        assert serialized_query_expander == {
            "type": "haystack.components.query.query_expander.QueryExpander",
            "init_parameters": {
                "chat_generator": {
                    "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                    "init_parameters": {
                        "model": "gpt-4.1-mini",
                        "streaming_callback": None,
                        "api_base_url": None,
                        "organization": None,
                        "generation_kwargs": {},
                        "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                        "timeout": None,
                        "max_retries": None,
                        "tools": None,
                        "tools_strict": False,
                        "http_client_kwargs": None,
                    },
                },
                "prompt_template": DEFAULT_PROMPT_TEMPLATE,
                "n_expansions": 2,
                "include_original_query": False,
            },
        }

    def test_from_dict(self, monkeypatch):
        """from_dict() round-trips the serialized form back into a working component."""
        monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345")

        data = {
            "type": "haystack.components.query.query_expander.QueryExpander",
            "init_parameters": {
                "chat_generator": {
                    "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                    "init_parameters": {
                        "model": "gpt-4.1-mini",
                        "streaming_callback": None,
                        "api_base_url": None,
                        "organization": None,
                        "generation_kwargs": {},
                        "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                        "timeout": None,
                        "max_retries": None,
                        "tools": None,
                        "tools_strict": False,
                        "http_client_kwargs": None,
                    },
                },
                "prompt_template": DEFAULT_PROMPT_TEMPLATE,
                "n_expansions": 2,
                "include_original_query": False,
            },
        }

        expander = QueryExpander.from_dict(data)

        assert expander.n_expansions == 2
        # Fix: compare boolean singletons with `is`, not `==` (PEP 8 / E712).
        assert expander.include_original_query is False
        assert expander.prompt_template == DEFAULT_PROMPT_TEMPLATE
        assert isinstance(expander.chat_generator, OpenAIChatGenerator)
        assert expander.chat_generator.model == "gpt-4.1-mini"
|
|
|
|
|
|
@pytest.mark.integration
class TestQueryExpanderIntegration:
    """End-to-end tests that exercise QueryExpander against the real OpenAI API."""

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    def test_query_expansion(self):
        """Three expansions plus the original query come back, all non-blank."""
        component = QueryExpander(n_expansions=3)
        component.warm_up()
        outcome = component.run("renewable energy sources")

        expanded = outcome["queries"]
        assert len(expanded) == 4
        assert "renewable energy sources" in expanded
        assert all(len(q.strip()) > 0 for q in expanded)

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    def test_different_domains(self):
        """Expansion behaves consistently across unrelated topic areas."""
        component = QueryExpander(n_expansions=2, include_original_query=False)
        component.warm_up()

        for topic in ("machine learning algorithms", "climate change effects", "quantum computing applications"):
            outcome = component.run(topic)

            # Exactly 2 expansions were requested (no original included) ...
            assert len(outcome["queries"]) == 2
            # ... and none of them should simply repeat the input query.
            assert topic not in outcome["queries"]
|