# haystack/test/agents/test_agent.py
import logging
import os
import re
from typing import Tuple
from test.conftest import MockRetriever, MockPromptNode
from unittest import mock
import pytest
from haystack import BaseComponent, Answer
from haystack.agents import Agent, AgentStep
from haystack.agents.base import Tool
from haystack.errors import AgentError
from haystack.nodes import PromptModel, PromptNode, PromptTemplate
from haystack.pipelines import ExtractiveQAPipeline, DocumentSearchPipeline, BaseStandardPipeline
@pytest.mark.unit
def test_add_and_overwrite_tool():
    """Registering a Tool under an existing name replaces the previous Tool."""
    agent = Agent(prompt_node=MockPromptNode())
    mock_retriever = MockRetriever()

    # A plain Node can be wrapped in a Tool and added to the Agent
    agent.add_tool(
        Tool(
            name="Retriever",
            pipeline_or_node=mock_retriever,
            description="useful for when you need to retrieve documents from your index",
        )
    )
    assert len(agent.tools) == 1
    assert "Retriever" in agent.tools
    assert agent.has_tool(tool_name="Retriever")
    assert isinstance(agent.tools["Retriever"].pipeline_or_node, BaseComponent)

    # Adding a Tool with the same name again must not create a duplicate entry
    agent.add_tool(
        Tool(
            name="Retriever",
            pipeline_or_node=mock_retriever,
            description="useful for when you need to retrieve documents from your index",
        )
    )

    # A Pipeline-backed Tool under the same name overwrites the Node-backed one
    search_pipeline = DocumentSearchPipeline(MockRetriever())
    agent.add_tool(
        Tool(
            name="Retriever",
            pipeline_or_node=search_pipeline,
            description="useful for when you need to retrieve documents from your index",
        )
    )
    assert len(agent.tools) == 1
    assert "Retriever" in agent.tools
    assert agent.has_tool(tool_name="Retriever")
    assert isinstance(agent.tools["Retriever"].pipeline_or_node, BaseStandardPipeline)
@pytest.mark.unit
def test_agent_chooses_no_action():
    """If the LLM output names no tool, the Agent must raise instead of looping."""
    agent = Agent(prompt_node=MockPromptNode())
    agent.add_tool(
        Tool(
            name="Retriever",
            pipeline_or_node=MockRetriever(),
            description="useful for when you need to retrieve documents from your index",
        )
    )
    expected_error = r"Could not identify the next tool or input for that tool from Agent's output.*"
    with pytest.raises(AgentError, match=expected_error):
        agent.run("How many letters does the name of the town where Christelle lives have?")
@pytest.mark.unit
def test_max_steps(caplog, monkeypatch):
    """An Agent stops and returns an empty generative Answer once max_steps is reached.

    Uses `logging.WARNING` (the documented level constant) instead of the
    deprecated undocumented alias `logging.WARN`.
    """
    agent = Agent(prompt_node=MockPromptNode(), max_steps=3)
    agent.add_tool(
        Tool(
            name="Retriever",
            pipeline_or_node=MockRetriever(),
            description="useful for when you need to retrieve documents from your index",
            output_variable="documents",
        )
    )

    # Force the Agent to always choose "Retriever" with "" as tool input, so it
    # can never reach a final answer and must run into the step limit.
    def mock_extract_tool_name_and_tool_input(self, pred: str) -> Tuple[str, str]:
        return "Retriever", ""

    monkeypatch.setattr(AgentStep, "extract_tool_name_and_tool_input", mock_extract_tool_name_and_tool_input)

    # max_steps as configured in the Agent's init method
    with caplog.at_level(logging.WARNING, logger="haystack.agents"):
        result = agent.run("Where does Christelle live?")
        assert result["answers"] == [Answer(answer="", type="generative")]
        assert "maximum number of iterations (3)" in caplog.text.lower()

    # max_steps passed per-call to run() overrides the init value
    with caplog.at_level(logging.WARNING, logger="haystack.agents"):
        result = agent.run("Where does Christelle live?", max_steps=2)
        assert result["answers"] == [Answer(answer="", type="generative")]
        assert "maximum number of iterations (2)" in caplog.text.lower()
@pytest.mark.unit
def test_run_tool():
    """_run_tool dispatches to the tool named in the LLM response and stringifies its output."""
    agent = Agent(prompt_node=MockPromptNode())
    agent.add_tool(
        Tool(
            name="Retriever",
            pipeline_or_node=MockRetriever(),
            description="useful for when you need to retrieve documents from your index",
            output_variable="documents",
        )
    )

    llm_response = (
        "need to find out what city he was born.\nTool: Retriever\nTool Input: Where was Jeremy McKinnon born"
    )
    result = agent._run_tool(AgentStep(prompt_node_response=llm_response))
    assert result == "[]"  # MockRetriever yields an empty list of documents
@pytest.mark.unit
def test_extract_tool_name_and_tool_input():
    """The tool pattern captures the tool name and its (optionally quoted) input."""
    tool_pattern = r'Tool:\s*(\w+)\s*Tool Input:\s*("?)([^"\n]+)\2\s*'
    response = "need to find out what city he was born.\nTool: Search\nTool Input: Where was Jeremy McKinnon born"

    step = AgentStep(prompt_node_response=response)
    tool_name, tool_input = step.extract_tool_name_and_tool_input(tool_pattern=tool_pattern)
    assert tool_name == "Search"
    assert tool_input == "Where was Jeremy McKinnon born"
@pytest.mark.unit
def test_extract_final_answer():
    """extract_final_answer returns the stripped text following 'Final Answer:'.

    Each fixture pairs the prompt-node response with its expected answer in a
    single tuple. The original kept two parallel lists joined by zip(), which
    silently truncates (skipping cases) if the lists ever drift out of sync.
    """
    examples = [
        ("have the final answer to the question.\nFinal Answer: Florida", "Florida"),
        ("Final Answer: 42 is the answer", "42 is the answer"),
        ("Final Answer: 1234", "1234"),
        ("Final Answer: Answer", "Answer"),
        ("Final Answer: This list: one and two and three", "This list: one and two and three"),
        ("Final Answer:42", "42"),
        ("Final Answer: ", ""),  # empty answer after the marker
        ("Final Answer: The answer is 99 ", "The answer is 99"),  # trailing space stripped
    ]
    for prompt_node_response, expected_answer in examples:
        agent_step = AgentStep(prompt_node_response=prompt_node_response)
        assert agent_step.extract_final_answer() == expected_answer
@pytest.mark.unit
def test_format_answer():
    """final_answer packages the extracted answer together with the original query."""
    step = AgentStep(prompt_node_response="have the final answer to the question.\nFinal Answer: Florida")
    formatted = step.final_answer(query="query")
    assert formatted["query"] == "query"
    assert formatted["answers"] == [Answer(answer="Florida", type="generative")]
@pytest.mark.unit
def test_final_answer_regex():
    """The pattern matches any 'Final Answer:' prefix (flexible whitespace around
    the colon) and rejects lines that do not start with it verbatim."""
    final_answer_pattern = r"Final Answer\s*:\s*(.*)"
    matching = [
        "Final Answer: 42 is the answer",
        "Final Answer: 1234",
        "Final Answer: Answer",
        "Final Answer: This list: one and two and three",
        "Final Answer:42",
        "Final Answer: ",
        "Final Answer: The answer is 99 ",
    ]
    non_matching = ["Final answer: 42 is the answer", "Final Answer", "The final answer is: 100"]
    for text in matching:
        assert re.match(final_answer_pattern, text) is not None
    for text in non_matching:
        assert re.match(final_answer_pattern, text) is None
@pytest.mark.integration
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
@pytest.mark.parametrize("retriever_with_docs, document_store_with_docs", [("bm25", "memory")], indirect=True)
def test_tool_result_extraction(reader, retriever_with_docs):
    """Tool.run() must always hand back its result as a plain string,
    regardless of whether it wraps a Pipeline, a PromptNode, or a Retriever."""
    # Pipeline as a Tool
    qa_pipeline = ExtractiveQAPipeline(reader, retriever_with_docs)
    search_tool = Tool(
        name="Search",
        pipeline_or_node=qa_pipeline,
        description="useful for when you need to answer "
        "questions about where people live. You "
        "should ask targeted questions",
        output_variable="answers",
    )
    result = search_tool.run("Where does Christelle live?")
    assert isinstance(result, str)
    assert result in ("Paris", "Madrid")

    # PromptNode as a Tool
    template = PromptTemplate("test", "Here is a question: {query}, Answer:")
    prompt_tool = Tool(
        name="Search",
        pipeline_or_node=PromptNode(default_prompt_template=template),
        description="N/A",
        output_variable="results",
    )
    result = prompt_tool.run(tool_input="What is the capital of Germany?")
    assert isinstance(result, str)
    assert "berlin" in result.lower()

    # Retriever as a Tool
    retriever_tool = Tool(
        name="Retriever",
        pipeline_or_node=retriever_with_docs,
        description="useful for when you need to retrieve documents from your index",
        output_variable="documents",
    )
    result = retriever_tool.run(tool_input="Where does Christelle live?")
    assert isinstance(result, str)
    assert "Christelle" in result
@pytest.mark.integration
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
@pytest.mark.parametrize("retriever_with_docs, document_store_with_docs", [("bm25", "memory")], indirect=True)
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_agent_run(reader, retriever_with_docs, document_store_with_docs):
    """End-to-end Agent run combining a Search pipeline tool and an LLM word-length tool."""
    qa_pipeline = ExtractiveQAPipeline(reader, retriever_with_docs)
    openai_model = PromptModel(model_name_or_path="text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"))
    agent_prompt_node = PromptNode(model_name_or_path=openai_model, stop_words=["Observation:"])

    # Few-shot PromptNode that answers with the character count of a single word
    counting_node = PromptNode(
        model_name_or_path=openai_model,
        default_prompt_template=PromptTemplate(
            name="calculator_template",
            prompt_text="When I give you a word, respond with the number of characters that this word contains.\n"
            "Word: Rome\nLength: 4\n"
            "Word: Arles\nLength: 5\n"
            "Word: Berlin\nLength: 6\n"
            "Word: {query}?\nLength: ",
        ),
    )

    agent = Agent(prompt_node=agent_prompt_node)
    agent.add_tool(
        Tool(
            name="Search",
            pipeline_or_node=qa_pipeline,
            description="useful for when you need to answer "
            "questions about where people live. You "
            "should ask targeted questions",
            output_variable="answers",
        )
    )
    agent.add_tool(
        Tool(
            name="Count",
            pipeline_or_node=counting_node,
            description="useful for when you need to count how many characters are in a word. Ask only with a single word.",
        )
    )

    # TODO Replace Count tool once more tools are implemented so that we do not need to account for off-by-one errors
    result = agent.run("How many characters are in the word Madrid?")
    assert any(digit in result["answers"][0].answer for digit in ["5", "6", "five", "six"])

    result = agent.run("How many letters does the name of the town where Christelle lives have?")
    assert any(digit in result["answers"][0].answer for digit in ["5", "6", "five", "six"])
@pytest.mark.integration
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
@pytest.mark.parametrize("retriever_with_docs, document_store_with_docs", [("bm25", "memory")], indirect=True)
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_agent_run_batch(reader, retriever_with_docs, document_store_with_docs):
    """Batch variant of test_agent_run: run_batch answers multiple queries in one call."""
    qa_pipeline = ExtractiveQAPipeline(reader, retriever_with_docs)
    openai_model = PromptModel(model_name_or_path="text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"))
    agent_prompt_node = PromptNode(model_name_or_path=openai_model, stop_words=["Observation:"])

    # Few-shot PromptNode that answers with the character count of a single word
    counting_node = PromptNode(
        model_name_or_path=openai_model,
        default_prompt_template=PromptTemplate(
            name="calculator_template",
            prompt_text="When I give you a word, respond with the number of characters that this word contains.\n"
            "Word: Rome\nLength: 4\n"
            "Word: Arles\nLength: 5\n"
            "Word: Berlin\nLength: 6\n"
            "Word: {query}\nLength: ",
        ),
    )

    agent = Agent(prompt_node=agent_prompt_node)
    agent.add_tool(
        Tool(
            name="Search",
            pipeline_or_node=qa_pipeline,
            description="useful for when you need to answer "
            "questions about where people live. You "
            "should ask targeted questions",
            output_variable="answers",
        )
    )
    agent.add_tool(
        Tool(
            name="Count",
            pipeline_or_node=counting_node,
            description="useful for when you need to count how many characters are in a word. Ask only with a single word.",
        )
    )

    results = agent.run_batch(
        queries=[
            "How many characters are in the word Madrid?",
            "How many letters does the name of the town where Christelle lives have?",
        ]
    )
    # TODO Replace Count tool once more tools are implemented so that we do not need to account for off-by-one errors
    assert any(digit in results["answers"][0][0].answer for digit in ["5", "6", "five", "six"])
    assert any(digit in results["answers"][1][0].answer for digit in ["5", "6", "five", "six"])
@pytest.mark.unit
def test_update_hash():
    """The Agent hash is only recomputed by an explicit update_hash(), not by add_tool()."""
    agent = Agent(prompt_node=MockPromptNode(), prompt_template=mock.Mock())
    # MD5 of the empty string — the hash before update_hash() is ever called
    initial_hash = "d41d8cd98f00b204e9800998ecf8427e"
    assert agent.hash == initial_hash

    agent.add_tool(
        Tool(
            name="Search",
            pipeline_or_node=mock.Mock(),
            description="useful for when you need to answer "
            "questions about where people live. You "
            "should ask targeted questions",
            output_variable="answers",
        )
    )
    assert agent.hash == initial_hash  # unchanged: add_tool does not rehash

    agent.add_tool(
        Tool(
            name="Count",
            pipeline_or_node=mock.Mock(),
            description="useful for when you need to count how many characters are in a word. Ask only with a single word.",
        )
    )
    assert agent.hash == initial_hash  # still unchanged

    agent.update_hash()
    assert agent.hash == "5ac8eca2f92c9545adcce3682b80d4c5"