# haystack/test/agents/test_agent_step.py
#
# 110 lines
# 3.4 KiB
# Python
# Raw Normal View History
#
# 2023-05-17 15:19:09 +02:00
import pytest
from haystack import Answer
from haystack.agents import AgentStep
from haystack.errors import AgentError
@pytest.fixture
def agent_step():
    """Provide a first-step ``AgentStep`` whose response and transcript are both "Hello".

    ``final_answer_pattern`` is passed as ``None`` explicitly; the tests that
    use this fixture exercise the behaviour that results from that.
    """
    step = AgentStep(
        current_step=1,
        max_steps=10,
        final_answer_pattern=None,
        prompt_node_response="Hello",
        transcript="Hello",
    )
    return step
@pytest.mark.unit
def test_create_next_step(agent_step):
    """Check ``create_next_step`` on a valid response list and on rejected inputs."""
    # Happy path: the step counter advances, the first list element becomes
    # the new response, and the transcript is carried over unchanged.
    successor = agent_step.create_next_step(["Hello again"])
    assert successor.current_step == 2
    assert successor.prompt_node_response == "Hello again"
    assert successor.transcript == "Hello"

    # Rejected inputs, in the original order: an invalid (dict) response,
    # then an empty list. Each must raise AgentError.
    for bad_response in ({}, []):
        with pytest.raises(AgentError):
            agent_step.create_next_step(bad_response)
@pytest.mark.unit
def test_final_answer(agent_step):
    """Check the dict returned by ``final_answer`` before and after exceeding ``max_steps``."""
    # Normal case: the "Hello" response is reported as a generative answer.
    outcome = agent_step.final_answer("query")
    assert outcome["query"] == "query"
    first_answer = outcome["answers"][0]
    assert isinstance(first_answer, Answer)
    assert first_answer.answer == "Hello"
    assert first_answer.type == "generative"
    assert outcome["transcript"] == "Hello"

    # Step 11 with max_steps=10: the answer text is expected to be empty.
    agent_step.current_step = 11
    outcome = agent_step.final_answer("query")
    exhausted_answer = outcome["answers"][0]
    assert exhausted_answer.answer == ""
@pytest.mark.unit
def test_is_last():
    """Check ``is_last`` for a matching answer, a non-matching answer, and the ``max_steps`` limit."""
    # No final_answer_pattern is passed, so the default applies; with the
    # response "Hello" the step is expected to count as the last one.
    step = AgentStep(current_step=1, max_steps=10, prompt_node_response="Hello", transcript="Hello")
    assert step.is_last()

    # A response that does not satisfy an explicit "Final Answer:" pattern
    # must not terminate the agent while steps remain.
    step.current_step = 1
    step.prompt_node_response = "final answer not satisfying pattern"
    step.final_answer_pattern = r"Final Answer\s*:\s*(.*)"
    assert not step.is_last()

    # Border cases: just below and exactly at max_steps are still not last.
    for step_number in (9, 10):
        step.current_step = step_number
        assert not step.is_last()

    # One past max_steps: last because the step budget is exhausted.
    step.current_step = 11
    assert step.is_last()
@pytest.mark.unit
def test_completed(agent_step):
    """Check how ``completed`` extends the transcript with and without an observation.

    The two calls are order-dependent: the second expectation builds on the
    transcript left behind by the first.
    """
    # Without an observation only the prompt node response ("Hello") is appended.
    transcript_without_observation = "HelloHello"
    agent_step.completed(None)
    assert agent_step.transcript == transcript_without_observation

    # With an observation, "Hello" is appended again, followed by the
    # observation line and a trailing "Thought:" marker.
    transcript_with_observation = "HelloHelloHello\nObservation: observation\nThought:"
    agent_step.completed("observation")
    assert agent_step.transcript == transcript_with_observation
@pytest.mark.unit
def test_repr(agent_step):
    """Check the exact ``repr`` of the fixture step, including the pattern shown for it."""
    expected_repr = (
        "AgentStep(current_step=1, max_steps=10, "
        "prompt_node_response=Hello, final_answer_pattern=^([\\s\\S]+)$, "
        "transcript=Hello)"
    )
    actual_repr = repr(agent_step)
    assert actual_repr == expected_repr
@pytest.mark.unit
def test_parse_final_answer(agent_step):
    """Check ``parse_final_answer`` for a matching and a non-matching pattern."""
    # The fixture's pattern matches the "Hello" response.
    parsed = agent_step.parse_final_answer()
    assert parsed == "Hello"

    # "goodbye" does not match "Hello", so no answer is extracted.
    agent_step.final_answer_pattern = "goodbye"
    parsed = agent_step.parse_final_answer()
    assert parsed is None
@pytest.mark.unit
def test_format_react_answer():
    """Check that ``final_answer`` extracts the text after "Final Answer:" from a ReAct-style response."""
    react_step = AgentStep(
        final_answer_pattern=r"Final Answer\s*:\s*(.*)",
        prompt_node_response="have the final answer to the question.\nFinal Answer: Florida",
    )

    result = react_step.final_answer(query="query")

    expected_answers = [Answer(answer="Florida", type="generative")]
    assert result["query"] == "query"
    assert result["answers"] == expected_answers