mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-13 19:17:56 +00:00
110 lines
3.4 KiB
Python
110 lines
3.4 KiB
Python
![]() |
import pytest
|
||
|
|
||
|
from haystack import Answer
|
||
|
from haystack.agents import AgentStep
|
||
|
from haystack.errors import AgentError
|
||
|
|
||
|
|
||
|
@pytest.fixture
def agent_step():
    """Provide a fresh ``AgentStep`` at step 1 of 10.

    Both the prompt node response and the transcript are "Hello", and
    ``final_answer_pattern`` is passed explicitly as ``None``.
    """
    step_kwargs = {
        "current_step": 1,
        "max_steps": 10,
        "final_answer_pattern": None,
        "prompt_node_response": "Hello",
        "transcript": "Hello",
    }
    return AgentStep(**step_kwargs)
|
||
|
|
||
|
|
||
|
@pytest.mark.unit
def test_create_next_step(agent_step):
    """Check ``create_next_step`` with a valid response list and two rejected inputs."""
    # Normal case: a one-element response list produces the following step.
    successor = agent_step.create_next_step(["Hello again"])
    assert successor.current_step == 2
    assert successor.prompt_node_response == "Hello again"
    assert successor.transcript == "Hello"

    # An invalid response (a dict) and then an empty response list
    # must each raise AgentError.
    for rejected_response in ({}, []):
        with pytest.raises(AgentError):
            agent_step.create_next_step(rejected_response)
|
||
|
|
||
|
|
||
|
@pytest.mark.unit
def test_final_answer(agent_step):
    """Check the dict returned by ``final_answer`` before and after exceeding ``max_steps``."""
    # Normal case: the answer text is taken from the prompt node response.
    outcome = agent_step.final_answer("query")
    assert outcome["query"] == "query"
    top_answer = outcome["answers"][0]
    assert isinstance(top_answer, Answer)
    assert top_answer.answer == "Hello"
    assert top_answer.type == "generative"
    assert outcome["transcript"] == "Hello"

    # With max_steps reached (step 11 of 10) the answer text is empty.
    agent_step.current_step = 11
    exhausted_outcome = agent_step.final_answer("query")
    assert exhausted_outcome["answers"][0].answer == ""
|
||
|
|
||
|
|
||
|
@pytest.mark.unit
def test_is_last():
    """Check ``is_last`` for a pattern match, a pattern mismatch, and the ``max_steps`` boundary."""
    # Last step: valid prompt_node_response combined with the default final_answer_pattern.
    step = AgentStep(current_step=1, max_steps=10, prompt_node_response="Hello", transcript="Hello")
    assert step.is_last()

    # Not last: the response does not satisfy the explicit pattern.
    step.current_step = 1
    step.prompt_node_response = "final answer not satisfying pattern"
    step.final_answer_pattern = r"Final Answer\s*:\s*(.*)"
    assert not step.is_last()

    # Border cases for max_steps: steps 9 and 10 are still not last.
    for boundary_step in (9, 10):
        step.current_step = boundary_step
        assert not step.is_last()

    # Going past max_steps makes the step last.
    step.current_step = 11
    assert step.is_last()
|
||
|
|
||
|
|
||
|
@pytest.mark.unit
def test_completed(agent_step):
    """Check how ``completed`` extends the transcript with and without an observation."""
    # Without an observation only the prompt node response is appended.
    agent_step.completed(None)
    assert agent_step.transcript == "HelloHello"

    # With an observation, "Hello" from prompt_node_response is appended again,
    # followed by the observation and a trailing "Thought:" marker.
    agent_step.completed("observation")
    expected_transcript = "HelloHelloHello\nObservation: observation\nThought:"
    assert agent_step.transcript == expected_transcript
|
||
|
|
||
|
|
||
|
@pytest.mark.unit
def test_repr(agent_step):
    """Check the exact ``repr`` of the fixture step."""
    # The fixture passes final_answer_pattern=None; the expected repr
    # shows the pattern ^([\s\S]+)$ in its place.
    expected_repr = (
        "AgentStep(current_step=1, max_steps=10, "
        "prompt_node_response=Hello, final_answer_pattern=^([\\s\\S]+)$, "
        "transcript=Hello)"
    )
    assert repr(agent_step) == expected_repr
|
||
|
|
||
|
|
||
|
@pytest.mark.unit
def test_parse_final_answer(agent_step):
    """Check ``parse_final_answer`` for a matching and a non-matching pattern."""
    # Pattern matches: the response text is returned.
    matched = agent_step.parse_final_answer()
    assert matched == "Hello"

    # Pattern does not match: None is returned.
    agent_step.final_answer_pattern = "goodbye"
    unmatched = agent_step.parse_final_answer()
    assert unmatched is None
|
||
|
|
||
|
|
||
|
@pytest.mark.unit
def test_format_react_answer():
    """Check that ``final_answer`` extracts the text captured by a ReAct-style pattern."""
    react_pattern = r"Final Answer\s*:\s*(.*)"
    react_response = "have the final answer to the question.\nFinal Answer: Florida"
    react_step = AgentStep(final_answer_pattern=react_pattern, prompt_node_response=react_response)

    result = react_step.final_answer(query="query")

    assert result["query"] == "query"
    # Only the captured group ("Florida") becomes the answer text.
    assert result["answers"] == [Answer(answer="Florida", type="generative")]
|