General Enhancements in agentchat 2.0 (#1906)

* work in progress

* wip

* groupchat type hints

* clean up

* formatting

* formatting

* clean up

* address comments

* better comment

* updates docstring a_send

* resolve comments

* agent.py back to original format

* resolve more comments

* rename carryover type exception

* revert next_agent changes + keeping UndefinedNextagent

* fixed ciruclar dependencies?

* fix cache tests

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
Wael Karkoub 2024-03-09 16:15:19 +01:00 committed by GitHub
parent 0e1a4b9fdd
commit 29b9c80c40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 195 additions and 138 deletions

View File

@ -11,6 +11,8 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Ty
import warnings
from openai import BadRequestError
from autogen.exception_utils import InvalidCarryOverType, SenderRequired
from ..coding.base import CodeExecutor
from ..coding.factory import CodeExecutorFactory
@ -77,7 +79,7 @@ class ConversableAgent(LLMAgent):
system_message: Optional[Union[str, List]] = "You are a helpful AI Assistant.",
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
human_input_mode: Optional[str] = "TERMINATE",
human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "TERMINATE",
function_map: Optional[Dict[str, Callable]] = None,
code_execution_config: Union[Dict, Literal[False]] = False,
llm_config: Optional[Union[Dict, Literal[False]]] = None,
@ -576,7 +578,7 @@ class ConversableAgent(LLMAgent):
recipient: Agent,
request_reply: Optional[bool] = None,
silent: Optional[bool] = False,
) -> ChatResult:
):
"""Send a message to another agent.
Args:
@ -608,9 +610,6 @@ class ConversableAgent(LLMAgent):
Raises:
ValueError: if the message can't be converted into a valid ChatCompletion message.
Returns:
ChatResult: a ChatResult object.
"""
message = self._process_message_before_send(message, recipient, silent)
# When the agent composes and sends the message, the role of the message is "assistant"
@ -629,7 +628,7 @@ class ConversableAgent(LLMAgent):
recipient: Agent,
request_reply: Optional[bool] = None,
silent: Optional[bool] = False,
) -> ChatResult:
):
"""(async) Send a message to another agent.
Args:
@ -661,9 +660,6 @@ class ConversableAgent(LLMAgent):
Raises:
ValueError: if the message can't be converted into a valid ChatCompletion message.
Returns:
ChatResult: an ChatResult object.
"""
message = self._process_message_before_send(message, recipient, silent)
# When the agent composes and sends the message, the role of the message is "assistant"
@ -857,7 +853,7 @@ class ConversableAgent(LLMAgent):
def initiate_chat(
self,
recipient: "ConversableAgent",
clear_history: Optional[bool] = True,
clear_history: bool = True,
silent: Optional[bool] = False,
cache: Optional[Cache] = None,
max_turns: Optional[int] = None,
@ -946,7 +942,7 @@ class ConversableAgent(LLMAgent):
async def a_initiate_chat(
self,
recipient: "ConversableAgent",
clear_history: Optional[bool] = True,
clear_history: bool = True,
silent: Optional[bool] = False,
cache: Optional[Cache] = None,
max_turns: Optional[int] = None,
@ -1524,8 +1520,6 @@ class ConversableAgent(LLMAgent):
- Tuple[bool, Union[str, Dict, None]]: A tuple containing a boolean indicating if the conversation
should be terminated, and a human reply which can be a string, a dictionary, or None.
"""
# Function implementation...
if config is None:
config = self
if messages is None:
@ -1839,6 +1833,7 @@ class ConversableAgent(LLMAgent):
reply_func = reply_func_tuple["reply_func"]
if "exclude" in kwargs and reply_func in kwargs["exclude"]:
continue
if self._match_trigger(reply_func_tuple["trigger"], sender):
if inspect.iscoroutinefunction(reply_func):
final, reply = await reply_func(
@ -1850,7 +1845,7 @@ class ConversableAgent(LLMAgent):
return reply
return self._default_auto_reply
def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List], sender: Agent) -> bool:
def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List], sender: Optional[Agent]) -> bool:
"""Check if the sender matches the trigger.
Args:
@ -1867,6 +1862,8 @@ class ConversableAgent(LLMAgent):
if trigger is None:
return sender is None
elif isinstance(trigger, str):
if sender is None:
raise SenderRequired()
return trigger == sender.name
elif isinstance(trigger, type):
return isinstance(sender, trigger)
@ -1875,7 +1872,7 @@ class ConversableAgent(LLMAgent):
return trigger == sender
elif isinstance(trigger, Callable):
rst = trigger(sender)
assert rst in [True, False], f"trigger {trigger} must return a boolean value."
assert isinstance(rst, bool), f"trigger {trigger} must return a boolean value."
return rst
elif isinstance(trigger, list):
return any(self._match_trigger(t, sender) for t in trigger)
@ -2154,7 +2151,7 @@ class ConversableAgent(LLMAgent):
elif isinstance(carryover, list):
context["message"] = context["message"] + "\nContext: \n" + ("\n").join([t for t in carryover])
else:
raise warnings.warn(
raise InvalidCarryOverType(
"Carryover should be a string or a list of strings. Not adding carryover to the message."
)
@ -2212,6 +2209,11 @@ class ConversableAgent(LLMAgent):
func for func in self.llm_config["functions"] if func["name"] != func_sig
]
else:
if not isinstance(func_sig, dict):
raise ValueError(
f"The function signature must be of the type dict. Received function signature type {type(func_sig)}"
)
self._assert_valid_name(func_sig["name"])
if "functions" in self.llm_config.keys():
self.llm_config["functions"] = [
@ -2248,6 +2250,10 @@ class ConversableAgent(LLMAgent):
tool for tool in self.llm_config["tools"] if tool["function"]["name"] != tool_sig
]
else:
if not isinstance(tool_sig, dict):
raise ValueError(
f"The tool signature must be of the type dict. Received tool signature type {type(tool_sig)}"
)
self._assert_valid_name(tool_sig["function"]["name"])
if "tools" in self.llm_config.keys():
self.llm_config["tools"] = [

View File

@ -3,27 +3,19 @@ import random
import re
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union, Tuple, Callable
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
from autogen.agentchat.agent import Agent
from autogen.agentchat.conversable_agent import ConversableAgent
from ..code_utils import content_str
from ..exception_utils import AgentNameConflict
from .agent import Agent
from .conversable_agent import ConversableAgent
from ..runtime_logging import logging_enabled, log_new_agent
from ..exception_utils import AgentNameConflict, NoEligibleSpeaker, UndefinedNextAgent
from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed
from ..runtime_logging import log_new_agent, logging_enabled
logger = logging.getLogger(__name__)
class NoEligibleSpeakerException(Exception):
"""Exception raised for early termination of a GroupChat."""
def __init__(self, message="No eligible speakers."):
self.message = message
super().__init__(self.message)
@dataclass
class GroupChat:
"""(In preview) A group chat class that contains the following data fields:
@ -76,10 +68,10 @@ class GroupChat:
max_round: Optional[int] = 10
admin_name: Optional[str] = "Admin"
func_call_filter: Optional[bool] = True
speaker_selection_method: Optional[Union[str, Callable]] = "auto"
speaker_selection_method: Union[Literal["auto", "manual", "random", "round_robin"], Callable] = "auto"
allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = None
allowed_or_disallowed_speaker_transitions: Optional[Dict] = None
speaker_transitions_type: Optional[str] = None
speaker_transitions_type: Literal["allowed", "disallowed", None] = None
enable_clear_history: Optional[bool] = False
send_introductions: Optional[bool] = False
@ -212,6 +204,10 @@ class GroupChat:
if agents is None:
agents = self.agents
# Ensure the provided list of agents is a subset of self.agents
if not set(agents).issubset(set(self.agents)):
raise UndefinedNextAgent()
# What index is the agent? (-1 if not present)
idx = self.agent_names.index(agent.name) if agent.name in self.agent_names else -1
@ -224,6 +220,9 @@ class GroupChat:
if self.agents[(offset + i) % len(self.agents)] in agents:
return self.agents[(offset + i) % len(self.agents)]
# Explicitly handle cases where no valid next agent exists in the provided subset.
raise UndefinedNextAgent()
def select_speaker_msg(self, agents: Optional[List[Agent]] = None) -> str:
"""Return the system message for selecting the next speaker. This is always the *first* message in the context."""
if agents is None:
@ -295,9 +294,7 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
if isinstance(self.speaker_selection_method, Callable):
selected_agent = self.speaker_selection_method(last_speaker, self)
if selected_agent is None:
raise NoEligibleSpeakerException(
"Custom speaker selection function returned None. Terminating conversation."
)
raise NoEligibleSpeaker("Custom speaker selection function returned None. Terminating conversation.")
elif isinstance(selected_agent, Agent):
if selected_agent in self.agents:
return selected_agent, self.agents, None
@ -378,9 +375,7 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
# this condition means last_speaker is a sink in the graph, then no agents are eligible
if last_speaker not in self.allowed_speaker_transitions_dict and is_last_speaker_in_group:
raise NoEligibleSpeakerException(
f"Last speaker {last_speaker.name} is not in the allowed_speaker_transitions_dict."
)
raise NoEligibleSpeaker(f"Last speaker {last_speaker.name} is not in the allowed_speaker_transitions_dict.")
# last_speaker is not in the group, so all agents are eligible
elif last_speaker not in self.allowed_speaker_transitions_dict and not is_last_speaker_in_group:
graph_eligible_agents = []
@ -618,7 +613,7 @@ class GroupChatManager(ConversableAgent):
else:
# admin agent is not found in the participants
raise
except NoEligibleSpeakerException:
except NoEligibleSpeaker:
# No eligible speaker, terminate the conversation
break

View File

@ -29,7 +29,7 @@ class UserProxyAgent(ConversableAgent):
name: str,
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
human_input_mode: Optional[str] = "ALWAYS",
human_input_mode: Literal["ALWAYS", "TERMINATE", "NEVER"] = "ALWAYS",
function_map: Optional[Dict[str, Callable]] = None,
code_execution_config: Optional[Union[Dict, Literal[False]]] = None,
default_auto_reply: Optional[Union[str, Dict, None]] = "",

View File

@ -1,4 +1,4 @@
from typing import List, Dict, Tuple, Callable
from typing import Any, List, Dict, Tuple, Callable
from .agent import Agent
@ -53,7 +53,7 @@ def gather_usage_summary(agents: List[Agent]) -> Tuple[Dict[str, any], Dict[str,
If none of the agents incurred any cost (not having a client), then the total_usage_summary and actual_usage_summary will be `{'total_cost': 0}`.
"""
def aggregate_summary(usage_summary: Dict[str, any], agent_summary: Dict[str, any]) -> None:
def aggregate_summary(usage_summary: Dict[str, Any], agent_summary: Dict[str, Any]) -> None:
if agent_summary is None:
return
usage_summary["total_cost"] += agent_summary.get("total_cost", 0)

View File

@ -1,3 +1,37 @@
class AgentNameConflict(Exception):
def __init__(self, msg="Found multiple agents with the same name.", *args, **kwargs):
super().__init__(msg, *args, **kwargs)
class NoEligibleSpeaker(Exception):
"""Exception raised for early termination of a GroupChat."""
def __init__(self, message="No eligible speakers."):
self.message = message
super().__init__(self.message)
class SenderRequired(Exception):
"""Exception raised when the sender is required but not provided."""
def __init__(self, message="Sender is required but not provided."):
self.message = message
super().__init__(self.message)
class InvalidCarryOverType(Exception):
"""Exception raised when the carryover type is invalid."""
def __init__(
self, message="Carryover should be a string or a list of strings. Not adding carryover to the message."
):
self.message = message
super().__init__(self.message)
class UndefinedNextAgent(Exception):
"""Exception raised when the provided next agents list does not overlap with agents in the group."""
def __init__(self, message="The provided agents list does not overlap with agents in the group."):
self.message = message
super().__init__(self.message)

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Union
from typing import Dict, List
import logging
from autogen.agentchat.groupchat import Agent

View File

@ -1,12 +1,14 @@
import os
import sys
import tempfile
import time
import pytest
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
import autogen
from autogen.agentchat import AssistantAgent, UserProxyAgent
from autogen.cache import Cache
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST, here
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from conftest import skip_openai, skip_redis # noqa: E402
@ -120,41 +122,41 @@ def run_conversation(cache_seed, human_input_mode="NEVER", max_consecutive_auto_
"config_list": config_list,
"max_tokens": 1024,
}
assistant = AssistantAgent(
"coding_agent",
llm_config=llm_config,
)
user = UserProxyAgent(
"user",
human_input_mode=human_input_mode,
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
max_consecutive_auto_reply=max_consecutive_auto_reply,
code_execution_config={
"work_dir": f"{here}/test_agent_scripts",
"use_docker": "python:3",
"timeout": 60,
},
llm_config=llm_config,
system_message="""Is code provided but not enclosed in ``` blocks?
If so, remind that code blocks need to be enclosed in ``` blocks.
Reply TERMINATE to end the conversation if the task is finished. Don't say appreciation.
If "Thank you" or "You\'re welcome" are said in the conversation, then say TERMINATE and that is your last message.""",
)
with tempfile.TemporaryDirectory() as work_dir:
assistant = AssistantAgent(
"coding_agent",
llm_config=llm_config,
)
user = UserProxyAgent(
"user",
human_input_mode=human_input_mode,
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
max_consecutive_auto_reply=max_consecutive_auto_reply,
code_execution_config={
"work_dir": work_dir,
"use_docker": "python:3",
"timeout": 60,
},
llm_config=llm_config,
system_message="""Is code provided but not enclosed in ``` blocks?
If so, remind that code blocks need to be enclosed in ``` blocks.
Reply TERMINATE to end the conversation if the task is finished. Don't say appreciation.
If "Thank you" or "You\'re welcome" are said in the conversation, then say TERMINATE and that is your last message.""",
)
user.initiate_chat(assistant, message="TERMINATE", cache=cache)
# should terminate without sending any message
assert assistant.last_message()["content"] == assistant.last_message(user)["content"] == "TERMINATE"
coding_task = "Print hello world to a file called hello.txt"
user.initiate_chat(assistant, message="TERMINATE", cache=cache)
# should terminate without sending any message
assert assistant.last_message()["content"] == assistant.last_message(user)["content"] == "TERMINATE"
coding_task = "Print hello world to a file called hello.txt"
# track how long this takes
user.initiate_chat(assistant, message=coding_task, cache=cache)
return user.chat_messages[assistant]
# track how long this takes
user.initiate_chat(assistant, message=coding_task, cache=cache)
return user.chat_messages[assistant]
def run_groupchat_conversation(cache, human_input_mode="NEVER", max_consecutive_auto_reply=5):
KEY_LOC = "notebook"
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
here = os.path.abspath(os.path.dirname(__file__))
config_list = autogen.config_list_from_json(
OAI_CONFIG_LIST,
file_location=KEY_LOC,
@ -167,41 +169,43 @@ def run_groupchat_conversation(cache, human_input_mode="NEVER", max_consecutive_
"config_list": config_list,
"max_tokens": 1024,
}
assistant = AssistantAgent(
"coding_agent",
llm_config=llm_config,
)
planner = AssistantAgent(
"planner",
llm_config=llm_config,
)
with tempfile.TemporaryDirectory() as work_dir:
assistant = AssistantAgent(
"coding_agent",
llm_config=llm_config,
)
user = UserProxyAgent(
"user",
human_input_mode=human_input_mode,
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
max_consecutive_auto_reply=max_consecutive_auto_reply,
code_execution_config={
"work_dir": f"{here}/test_agent_scripts",
"use_docker": "python:3",
"timeout": 60,
},
system_message="""Is code provided but not enclosed in ``` blocks?
If so, remind that code blocks need to be enclosed in ``` blocks.
Reply TERMINATE to end the conversation if the task is finished. Don't say appreciation.
If "Thank you" or "You\'re welcome" are said in the conversation, then say TERMINATE and that is your last message.""",
)
planner = AssistantAgent(
"planner",
llm_config=llm_config,
)
group_chat = autogen.GroupChat(
agents=[planner, assistant, user],
messages=[],
max_round=4,
speaker_selection_method="round_robin",
)
manager = autogen.GroupChatManager(groupchat=group_chat, llm_config=llm_config)
user = UserProxyAgent(
"user",
human_input_mode=human_input_mode,
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
max_consecutive_auto_reply=max_consecutive_auto_reply,
code_execution_config={
"work_dir": work_dir,
"use_docker": "python:3",
"timeout": 60,
},
system_message="""Is code provided but not enclosed in ``` blocks?
If so, remind that code blocks need to be enclosed in ``` blocks.
Reply TERMINATE to end the conversation if the task is finished. Don't say appreciation.
If "Thank you" or "You\'re welcome" are said in the conversation, then say TERMINATE and that is your last message.""",
)
coding_task = "Print hello world to a file called hello.txt"
group_chat = autogen.GroupChat(
agents=[planner, assistant, user],
messages=[],
max_round=4,
speaker_selection_method="round_robin",
)
manager = autogen.GroupChatManager(groupchat=group_chat, llm_config=llm_config)
user.initiate_chat(manager, message=coding_task, cache=cache)
return user.chat_messages[list(user.chat_messages.keys())[-0]]
coding_task = "Print hello world to a file called hello.txt"
user.initiate_chat(manager, message=coding_task, cache=cache)
return user.chat_messages[list(user.chat_messages.keys())[-0]]

View File

@ -17,6 +17,7 @@ import autogen
from autogen.agentchat import ConversableAgent, UserProxyAgent
from autogen.agentchat.conversable_agent import register_function
from autogen.exception_utils import InvalidCarryOverType, SenderRequired
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
from conftest import MOCK_OPEN_AI_API_KEY, skip_openai
@ -462,6 +463,10 @@ def test_generate_reply():
dummy_agent_2.generate_reply(messages=None, sender=dummy_agent_1)["content"] == "15"
), "generate_reply not working when messages is None"
dummy_agent_2.register_reply(["str", None], ConversableAgent.generate_oai_reply)
with pytest.raises(SenderRequired):
dummy_agent_2.generate_reply(messages=messages, sender=None)
def test_generate_reply_raises_on_messages_and_sender_none(conversable_agent):
with pytest.raises(AssertionError):
@ -1106,6 +1111,27 @@ def test_process_before_send():
print_mock.assert_called_once_with(message="hello")
def test_messages_with_carryover():
agent1 = autogen.ConversableAgent(
"alice",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is alice speaking.",
)
context = dict(message="hello", carryover="Testing carryover.")
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, str)
context = dict(message="hello", carryover=["Testing carryover.", "This should pass"])
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, str)
context = dict(message="hello", carryover=3)
with pytest.raises(InvalidCarryOverType):
agent1.generate_init_message(**context)
if __name__ == "__main__":
# test_trigger()
# test_context()

View File

@ -1,13 +1,15 @@
#!/usr/bin/env python3 -m pytest
from typing import Any, Dict, List, Optional, Type
from autogen import AgentNameConflict, Agent, GroupChat
import pytest
from unittest import mock
import builtins
import autogen
import json
import sys
from typing import Any, Dict, List, Optional
from unittest import mock
import pytest
import autogen
from autogen import Agent, GroupChat
from autogen.exception_utils import AgentNameConflict, UndefinedNextAgent
def test_func_call_groupchat():
@ -399,34 +401,21 @@ def test_termination():
def test_next_agent():
agent1 = autogen.ConversableAgent(
"alice",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is alice speaking.",
)
agent2 = autogen.ConversableAgent(
"bob",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is bob speaking.",
)
agent3 = autogen.ConversableAgent(
"sam",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is sam speaking.",
)
agent4 = autogen.ConversableAgent(
"sally",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is sally speaking.",
)
def create_agent(name: str) -> autogen.ConversableAgent:
return autogen.ConversableAgent(
name,
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply=f"This is {name} speaking.",
)
agent1 = create_agent("alice")
agent2 = create_agent("bob")
agent3 = create_agent("sam")
agent4 = create_agent("sally")
agent5 = create_agent("samantha")
agent6 = create_agent("robert")
# Test empty is_termination_msg function
groupchat = autogen.GroupChat(
@ -448,6 +437,9 @@ def test_next_agent():
assert groupchat.next_agent(agent4, [agent1, agent3]) == agent1
assert groupchat.next_agent(agent4, [agent1, agent2, agent3]) == agent1
with pytest.raises(UndefinedNextAgent):
groupchat.next_agent(agent4, [agent5, agent6])
def test_send_intros():
agent1 = autogen.ConversableAgent(