autogen/test/agentchat/contrib/test_teachable_agent.py
Maxim Saplin c80df8acab
Skip tests that depend on OpenAI via --skip-openai (#1097)
* --skip-openai

* All tests pass

* Update build.yml

* Update Contribute.md

* Fix for failing Ubuntu tests

* More tests skipped, fixing 3.10 build

* Apply suggestions from code review

Co-authored-by: Qingyun Wu <qingyun0327@gmail.com>

* Added more comments

* fixed test__wrap_function_*

---------

Co-authored-by: Qingyun Wu <qingyun0327@gmail.com>
Co-authored-by: Davor Runje <davor@airt.ai>
2023-12-31 19:37:21 +00:00

221 lines
9.0 KiB
Python

import pytest
import os
import sys
from autogen import ConversableAgent, config_list_from_json
from conftest import skip_openai
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from test_assistant_agent import OAI_CONFIG_LIST, KEY_LOC # noqa: E402
try:
from openai import OpenAI
from autogen.agentchat.contrib.teachable_agent import TeachableAgent
except ImportError:
skip = True
else:
skip = False or skip_openai
try:
from termcolor import colored
except ImportError:
def colored(x, *args, **kwargs):
return x
# Set verbosity levels to maximize code coverage.
qa_verbosity = 0 # 0 for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
skill_verbosity = 3 # 0 for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
assert_on_error = False # GPT-4 nearly always succeeds on these unit tests, but GPT-3.5 is a bit less reliable.
recall_threshold = 1.5 # Higher numbers allow more (but less relevant) memos to be recalled.
cache_seed = None # Use an int to seed the response cache. Use None to disable caching.
# Specify the model to use by uncommenting one of the following lines.
# filter_dict={"model": ["gpt-4-0613"]}
# filter_dict={"model": ["gpt-3.5-turbo-0613"]}
# filter_dict={"model": ["gpt-4"]}
filter_dict = {"model": ["gpt-35-turbo-16k", "gpt-3.5-turbo-16k"]}
def create_teachable_agent(reset_db=False, verbosity=0):
"""Instantiates a TeachableAgent using the settings from the top of this file."""
# Load LLM inference endpoints from an env variable or a file
# See https://microsoft.github.io/autogen/docs/FAQ#set-your-api-endpoints
# and OAI_CONFIG_LIST_sample
config_list = config_list_from_json(env_or_file=OAI_CONFIG_LIST, filter_dict=filter_dict, file_location=KEY_LOC)
teachable_agent = TeachableAgent(
name="teachableagent",
llm_config={"config_list": config_list, "timeout": 120, "cache_seed": cache_seed},
teach_config={
"verbosity": verbosity,
"reset_db": reset_db,
"path_to_db_dir": "./tmp/teachable_agent_db",
"recall_threshold": recall_threshold,
},
)
return teachable_agent
def check_agent_response(teachable_agent, user, correct_answer):
"""Checks whether the agent's response contains the correct answer, and returns the number of errors (1 or 0)."""
agent_response = user.last_message(teachable_agent)["content"]
if correct_answer not in agent_response:
print(colored(f"\nTEST FAILED: EXPECTED ANSWER {correct_answer} NOT FOUND IN AGENT RESPONSE", "light_red"))
if assert_on_error:
assert correct_answer in agent_response
return 1
else:
print(colored(f"\nTEST PASSED: EXPECTED ANSWER {correct_answer} FOUND IN AGENT RESPONSE", "light_cyan"))
return 0
def use_question_answer_phrasing():
"""Tests whether the teachable agent can answer a question after being taught the answer in a previous chat."""
print(colored("\nTEST QUESTION-ANSWER PHRASING", "light_cyan"))
num_errors, num_tests = 0, 0
teachable_agent = create_teachable_agent(
reset_db=True, verbosity=qa_verbosity
) # For a clean test, clear the agent's memory.
user = ConversableAgent("user", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
# Prepopulate memory with a few arbitrary memos, just to make retrieval less trivial.
teachable_agent.prepopulate_db()
# Ask the teachable agent to do something using terminology it doesn't understand.
user.initiate_chat(recipient=teachable_agent, message="What is the twist of 5 and 7?")
# Explain the terminology to the teachable agent.
user.send(
recipient=teachable_agent,
message="Actually, the twist of two or more numbers is their product minus their sum. Try again.",
)
num_errors += check_agent_response(teachable_agent, user, "23")
num_tests += 1
# Let the teachable agent remember things that should be learned from this chat.
teachable_agent.learn_from_user_feedback()
# Now start a new chat to clear the context, and require the teachable agent to use its new knowledge.
print(colored("\nSTARTING A NEW CHAT WITH EMPTY CONTEXT", "light_cyan"))
user.initiate_chat(recipient=teachable_agent, message="What's the twist of 8 and 3 and 2?")
num_errors += check_agent_response(teachable_agent, user, "35")
num_tests += 1
# Wrap up.
teachable_agent.close_db()
return num_errors, num_tests
def use_task_advice_pair_phrasing():
"""Tests whether the teachable agent can demonstrate a new skill after being taught a task-advice pair in a previous chat."""
print(colored("\nTEST TASK-ADVICE PHRASING", "light_cyan"))
num_errors, num_tests = 0, 0
teachable_agent = create_teachable_agent(
reset_db=True, verbosity=skill_verbosity # For a clean test, clear the teachable agent's memory.
)
user = ConversableAgent("user", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
# Prepopulate memory with a few arbitrary memos, just to make retrieval less trivial.
teachable_agent.prepopulate_db()
# Ask the teachable agent to do something, and provide some helpful advice.
user.initiate_chat(
recipient=teachable_agent,
message="Compute the twist of 5 and 7. Here's a hint: The twist of two or more numbers is their product minus their sum.",
)
num_errors += check_agent_response(teachable_agent, user, "23")
num_tests += 1
# Let the teachable agent remember things that should be learned from this chat.
teachable_agent.learn_from_user_feedback()
# Now start a new chat to clear the context, and require the teachable agent to use its new knowledge.
print(colored("\nSTARTING A NEW CHAT WITH EMPTY CONTEXT", "light_cyan"))
user.initiate_chat(recipient=teachable_agent, message="Please calculate the twist of 8 and 3 and 2.")
num_errors += check_agent_response(teachable_agent, user, "35")
num_tests += 1
# Wrap up.
teachable_agent.close_db()
return num_errors, num_tests
@pytest.mark.skipif(
skip,
reason="do not run if dependency is not installed",
)
def test_teachability_code_paths():
"""Runs this file's unit tests."""
total_num_errors, total_num_tests = 0, 0
num_trials = 1 # Set to a higher number to get a more accurate error rate.
for trial in range(num_trials):
num_errors, num_tests = use_question_answer_phrasing()
total_num_errors += num_errors
total_num_tests += num_tests
num_errors, num_tests = use_task_advice_pair_phrasing()
total_num_errors += num_errors
total_num_tests += num_tests
print(colored(f"\nTRIAL {trial + 1} OF {num_trials} FINISHED", "light_cyan"))
if total_num_errors == 0:
print(colored("\nTEACHABLE AGENT TESTS FINISHED WITH ZERO ERRORS", "light_cyan"))
else:
print(
colored(
f"\nTEACHABLE AGENT TESTS FINISHED WITH {total_num_errors} / {total_num_tests} TOTAL ERRORS ({100.0 * total_num_errors / total_num_tests}%)",
"light_red",
)
)
@pytest.mark.skipif(
skip,
reason="do not run if dependency is not installed",
)
def test_teachability_accuracy():
"""A very cheap and fast test of teachability accuracy."""
print(colored("\nTEST TEACHABILITY ACCURACY", "light_cyan"))
num_trials = 10 # The expected probability of failure is about 0.3 on each trial.
for trial in range(num_trials):
teachable_agent = create_teachable_agent(
reset_db=True, verbosity=0
) # For a clean test, clear the agent's memory.
user = ConversableAgent("user", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
# Prepopulate memory with a few arbitrary memos, just to make retrieval less trivial.
teachable_agent.prepopulate_db()
# Tell the teachable agent something it wouldn't already know.
user.initiate_chat(recipient=teachable_agent, message="My favorite color is teal.")
# Let the teachable agent remember things that should be learned from this chat.
teachable_agent.learn_from_user_feedback()
# Now start a new chat to clear the context, and ask the teachable agent about the new information.
print(colored("\nSTARTING A NEW CHAT WITH EMPTY CONTEXT", "light_cyan"))
user.initiate_chat(recipient=teachable_agent, message="What's my favorite color?")
num_errors = check_agent_response(teachable_agent, user, "teal")
print(colored(f"\nTRIAL {trial + 1} OF {num_trials} FINISHED", "light_cyan"))
# Wrap up.
teachable_agent.close_db()
# Exit on the first success.
if num_errors == 0:
return
# All trials failed.
assert False, "test_teachability_accuracy() failed on all {} trials.".format(num_trials)
if __name__ == "__main__":
"""Runs this file's unit tests from the command line."""
test_teachability_code_paths()
test_teachability_accuracy()