diff --git a/examples/orchestrator.py b/examples/orchestrator.py new file mode 100644 index 000000000..840d35a39 --- /dev/null +++ b/examples/orchestrator.py @@ -0,0 +1,155 @@ +import argparse +import asyncio +import json +import logging +import os +from typing import Annotated, Callable + +import openai +from agnext.agent_components.function_executor._impl.in_process_function_executor import ( + InProcessFunctionExecutor, +) +from agnext.agent_components.model_client import OpenAI +from agnext.agent_components.types import SystemMessage +from agnext.application_components import ( + SingleThreadedAgentRuntime, +) +from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent +from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent +from agnext.chat.patterns.orchestrator_chat import OrchestratorChat +from agnext.chat.types import TextMessage +from agnext.core import Agent, AgentRuntime +from agnext.core.intervention import DefaultInterventionHandler, DropMessage +from tavily import TavilyClient +from typing_extensions import Any, override + +logging.basicConfig(level=logging.WARNING) +logging.getLogger("agnext").setLevel(logging.DEBUG) + + +class LoggingHandler(DefaultInterventionHandler): # type: ignore + send_color = "\033[31m" + response_color = "\033[34m" + reset_color = "\033[0m" + + @override + async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: # type: ignore + if sender is None: + print(f"{self.send_color}Sending message to {recipient.name}:{self.reset_color} {message}") + else: + print( + f"{self.send_color}Sending message from {sender.name} to {recipient.name}:{self.reset_color} {message}" + ) + return message + + @override + async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: # type: ignore + if recipient is None: + print(f"{self.response_color}Received response from {sender.name}:{self.reset_color} {message}") + 
else: + print( + f"{self.response_color}Received response from {sender.name} to {recipient.name}:{self.reset_color} {message}" + ) + return message + + +def software_development(runtime: AgentRuntime) -> OrchestratorChat: # type: ignore + developer = ChatCompletionAgent( + name="Developer", + description="A developer that writes code.", + runtime=runtime, + system_messages=[SystemMessage("You are a Python developer.")], + model_client=OpenAI(model="gpt-4-turbo"), + ) + + tester_oai_assistant = openai.beta.assistants.create( + model="gpt-4-turbo", + description="A software tester that runs test cases and reports results.", + instructions="You are a software tester that runs test cases and reports results.", + ) + tester_oai_thread = openai.beta.threads.create() + tester = OpenAIAssistantAgent( + name="Tester", + description="A software tester that runs test cases and reports results.", + runtime=runtime, + client=openai.AsyncClient(), + assistant_id=tester_oai_assistant.id, + thread_id=tester_oai_thread.id, + ) + + def search(query: Annotated[str, "The search query."]) -> Annotated[str, "The search results."]: + """Search the web.""" + client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY")) + result = client.search(query) # type: ignore + if result: + return json.dumps(result, indent=2, ensure_ascii=False) # type: ignore + return "No results found." 
+ + function_executor = InProcessFunctionExecutor(functions=[search]) + + product_manager = ChatCompletionAgent( + name="ProductManager", + description="A product manager that performs research and comes up with specs.", + runtime=runtime, + system_messages=[ + SystemMessage("You are a product manager good at translating customer needs into software specifications."), + SystemMessage("You can use the search tool to find information on the web."), + ], + model_client=OpenAI(model="gpt-4-turbo"), + function_executor=function_executor, + ) + + planner = ChatCompletionAgent( + name="Planner", + description="A planner that organizes and schedules tasks.", + runtime=runtime, + system_messages=[SystemMessage("You are a planner of complex tasks.")], + model_client=OpenAI(model="gpt-4-turbo"), + ) + + orchestrator = ChatCompletionAgent( + name="Orchestrator", + description="An orchestrator that coordinates the team.", + runtime=runtime, + system_messages=[ + SystemMessage("You are an orchestrator that coordinates the team to complete a complex task.") + ], + model_client=OpenAI(model="gpt-4-turbo"), + ) + + return OrchestratorChat( + "OrchestratorChat", + "A software development team.", + runtime, + orchestrator=orchestrator, + planner=planner, + specialists=[developer, product_manager, tester], + ) + + +async def run(message: str, user: str, scenario: Callable[[AgentRuntime], OrchestratorChat]) -> None: # type: ignore + runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler()) + chat = scenario(runtime) + response = runtime.send_message(TextMessage(content=message, source=user), chat) + while not response.done(): + await runtime.process_next() + print((await response).content) # type: ignore + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run a orchestrator demo.") + choices = {"software_development": software_development} + parser.add_argument( + "--scenario", + choices=list(choices.keys()), + help="The scenario to demo.", + 
default="software_development", + ) + parser.add_argument( + "--user", + default="Customer", + help="The user to send the message. Default is 'Customer'.", + ) + parser.add_argument("--message", help="The message to send.", required=True) + args = parser.parse_args() + asyncio.run(run(args.message, args.user, choices[args.scenario])) diff --git a/examples/patterns.py b/examples/patterns.py deleted file mode 100644 index 797b58b8a..000000000 --- a/examples/patterns.py +++ /dev/null @@ -1,268 +0,0 @@ -import argparse -import asyncio -import logging - -import openai -from agnext.agent_components.types import SystemMessage -from agnext.application_components import ( - SingleThreadedAgentRuntime, -) -from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent -from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent -from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput -from agnext.chat.patterns.orchestrator_chat import OrchestratorChat -from agnext.chat.types import TextMessage -from agnext.core._agent import Agent -from agnext.agent_components.model_client import OpenAI -from agnext.core.intervention import DefaultInterventionHandler, DropMessage -from typing_extensions import Any, override - -logging.basicConfig(level=logging.WARNING) -logging.getLogger("agnext").setLevel(logging.DEBUG) - - -class ConcatOutput(GroupChatOutput): - def __init__(self) -> None: - self._output = "" - - def on_message_received(self, message: Any) -> None: - match message: - case TextMessage(content=content): - self._output += content - case _: - ... 
- - def get_output(self) -> Any: - return self._output - - def reset(self) -> None: - self._output = "" - - -class LoggingHandler(DefaultInterventionHandler): - send_color = "\033[31m" - response_color = "\033[34m" - reset_color = "\033[0m" - - @override - async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: - if sender is None: - print(f"{self.send_color}Sending message to {recipient.name}:{self.reset_color} {message}") - else: - print( - f"{self.send_color}Sending message from {sender.name} to {recipient.name}:{self.reset_color} {message}" - ) - return message - - @override - async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: - if recipient is None: - print(f"{self.response_color}Received response from {sender.name}:{self.reset_color} {message}") - else: - print( - f"{self.response_color}Received response from {sender.name} to {recipient.name}:{self.reset_color} {message}" - ) - return message - - -async def group_chat(message: str) -> None: - runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler()) - - joe_oai_assistant = openai.beta.assistants.create( - model="gpt-3.5-turbo", - name="Joe", - instructions="You are a commedian named Joe. Make the audience laugh.", - ) - joe_oai_thread = openai.beta.threads.create() - joe = OpenAIAssistantAgent( - name="Joe", - description="Joe the commedian.", - runtime=runtime, - client=openai.AsyncClient(), - assistant_id=joe_oai_assistant.id, - thread_id=joe_oai_thread.id, - ) - - cathy_oai_assistant = openai.beta.assistants.create( - model="gpt-3.5-turbo", - name="Cathy", - instructions="You are a poet named Cathy. 
Write beautiful poems.", - ) - cathy_oai_thread = openai.beta.threads.create() - cathy = OpenAIAssistantAgent( - name="Cathy", - description="Cathy the poet.", - runtime=runtime, - client=openai.AsyncClient(), - assistant_id=cathy_oai_assistant.id, - thread_id=cathy_oai_thread.id, - ) - - chat = GroupChat( - "Host", - "A round-robin chat room.", - runtime, - [joe, cathy], - num_rounds=5, - output=ConcatOutput(), - ) - - response = runtime.send_message(TextMessage(content=message, source="host"), chat) - - while not response.done(): - await runtime.process_next() - - await response - - -async def orchestrator_oai_assistant(message: str) -> None: - runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler()) - - developer_oai_assistant = openai.beta.assistants.create( - model="gpt-3.5-turbo", - name="Developer", - instructions="You are a Python developer.", - ) - developer_oai_thread = openai.beta.threads.create() - developer = OpenAIAssistantAgent( - name="Developer", - description="A developer that writes code.", - runtime=runtime, - client=openai.AsyncClient(), - assistant_id=developer_oai_assistant.id, - thread_id=developer_oai_thread.id, - ) - - product_manager_oai_assistant = openai.beta.assistants.create( - model="gpt-3.5-turbo", - name="ProductManager", - instructions="You are a product manager good at translating customer needs into software specifications.", - ) - product_manager_oai_thread = openai.beta.threads.create() - product_manager = OpenAIAssistantAgent( - name="ProductManager", - description="A product manager that plans and comes up with specs.", - runtime=runtime, - client=openai.AsyncClient(), - assistant_id=product_manager_oai_assistant.id, - thread_id=product_manager_oai_thread.id, - ) - - planner_oai_assistant = openai.beta.assistants.create( - model="gpt-4-turbo", - name="Planner", - instructions="You are a planner of complex tasks.", - ) - planner_oai_thread = openai.beta.threads.create() - planner = OpenAIAssistantAgent( - 
name="Planner", - description="A planner that organizes and schedules tasks.", - runtime=runtime, - client=openai.AsyncClient(), - assistant_id=planner_oai_assistant.id, - thread_id=planner_oai_thread.id, - ) - - orchestrator_oai_assistant = openai.beta.assistants.create( - model="gpt-4-turbo", - name="Orchestrator", - instructions="You are an orchestrator that coordinates the team to complete a complex task.", - ) - orchestrator_oai_thread = openai.beta.threads.create() - orchestrator = OpenAIAssistantAgent( - name="Orchestrator", - description="An orchestrator that coordinates the team.", - runtime=runtime, - client=openai.AsyncClient(), - assistant_id=orchestrator_oai_assistant.id, - thread_id=orchestrator_oai_thread.id, - ) - - chat = OrchestratorChat( - "OrchestratorChat", - "A software development team.", - runtime, - orchestrator=orchestrator, - planner=planner, - specialists=[developer, product_manager], - ) - - response = runtime.send_message(TextMessage(content=message, source="Customer"), chat) - - while not response.done(): - await runtime.process_next() - - print((await response).content) # type: ignore - - -async def orchestrator_chat_completion(message: str) -> None: - runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler()) - - developer = ChatCompletionAgent( - name="Developer", - description="A developer that writes code.", - runtime=runtime, - system_messages=[SystemMessage("You are a Python developer.")], - model_client=OpenAI(model="gpt-3.5-turbo"), - ) - - product_manager = ChatCompletionAgent( - name="ProductManager", - description="A product manager that plans and comes up with specs.", - runtime=runtime, - system_messages=[ - SystemMessage("You are a product manager good at translating customer needs into software specifications.") - ], - model_client=OpenAI(model="gpt-3.5-turbo"), - ) - - planner = ChatCompletionAgent( - name="Planner", - description="A planner that organizes and schedules tasks.", - runtime=runtime, - 
system_messages=[SystemMessage("You are a planner of complex tasks.")], - model_client=OpenAI(model="gpt-4-turbo"), - ) - - orchestrator = ChatCompletionAgent( - name="Orchestrator", - description="An orchestrator that coordinates the team.", - runtime=runtime, - system_messages=[ - SystemMessage("You are an orchestrator that coordinates the team to complete a complex task.") - ], - model_client=OpenAI(model="gpt-4-turbo"), - ) - - chat = OrchestratorChat( - "OrchestratorChat", - "A software development team.", - runtime, - orchestrator=orchestrator, - planner=planner, - specialists=[developer, product_manager], - ) - - response = runtime.send_message(TextMessage(content=message, source="Customer"), chat) - - while not response.done(): - await runtime.process_next() - - print((await response).content) # type: ignore - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run a pattern demo.") - choices = { - "group_chat": group_chat, - "orchestrator_oai_assistant": orchestrator_oai_assistant, - "orchestrator_chat_completion": orchestrator_chat_completion, - } - parser.add_argument( - "--pattern", - choices=list(choices.keys()), - help="The pattern to demo.", - ) - parser.add_argument("--message", help="The message to send.") - args = parser.parse_args() - asyncio.run(choices[args.pattern](args.message)) diff --git a/pyproject.toml b/pyproject.toml index e79ae9535..7545797c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ disallow_untyped_decorators = true disallow_any_unimported = true [tool.pyright] -include = ["src", "examples", "tests"] +include = ["src", "tests"] typeCheckingMode = "strict" reportUnnecessaryIsInstance = false reportMissingTypeStubs = false diff --git a/src/agnext/agent_components/function_executor/__init__.py b/src/agnext/agent_components/function_executor/__init__.py index 7c3a41bd8..0300488af 100644 --- a/src/agnext/agent_components/function_executor/__init__.py +++ 
b/src/agnext/agent_components/function_executor/__init__.py @@ -1,4 +1,9 @@ -from ._base import FunctionExecutor, FunctionInfo, into_function_definition +from ._base import Function, FunctionExecutor, into_function_signature from ._impl.in_process_function_executor import InProcessFunctionExecutor -__all__ = ["FunctionExecutor", "FunctionInfo", "into_function_definition", "InProcessFunctionExecutor"] +__all__ = [ + "FunctionExecutor", + "Function", + "into_function_signature", + "InProcessFunctionExecutor", +] diff --git a/src/agnext/agent_components/function_executor/_base.py b/src/agnext/agent_components/function_executor/_base.py index afad5d6b7..e80db24ba 100644 --- a/src/agnext/agent_components/function_executor/_base.py +++ b/src/agnext/agent_components/function_executor/_base.py @@ -4,7 +4,13 @@ from typing import Any, Callable, Dict, Protocol, TypedDict, Union, runtime_chec from typing_extensions import NotRequired, Required from ..function_utils import get_function_schema -from ..types import FunctionDefinition +from ..types import FunctionSignature + + +class Function(TypedDict): + func: Required[Callable[..., Any]] + name: NotRequired[str] + description: NotRequired[str] @runtime_checkable @@ -12,25 +18,23 @@ class FunctionExecutor(Protocol): async def execute_function(self, function_name: str, arguments: Dict[str, Any]) -> str: ... @property - def functions(self) -> Sequence[str]: ... + def functions(self) -> Sequence[Function]: ... 
+ + @property + def function_signatures(self) -> Sequence[FunctionSignature]: + return [into_function_signature(func) for func in self.functions] -class FunctionInfo(TypedDict): - func: Required[Callable[..., Any]] - name: NotRequired[str] - description: NotRequired[str] - - -def into_function_definition( - func_info: Union[FunctionInfo, FunctionDefinition, Callable[..., Any]], -) -> FunctionDefinition: - if isinstance(func_info, FunctionDefinition): - return func_info - elif isinstance(func_info, dict): - name = func_info.get("name", func_info["func"].__name__) - description = func_info.get("description", "") - parameters = get_function_schema(func_info["func"], description="", name="")["function"]["parameters"] - return FunctionDefinition(name=name, description=description, parameters=parameters) +def into_function_signature( + func_data: Union[Function, FunctionSignature, Callable[..., Any]], +) -> FunctionSignature: + if isinstance(func_data, FunctionSignature): + return func_data + elif isinstance(func_data, dict): + name = func_data.get("name", func_data["func"].__name__) + description = func_data.get("description", "") + parameters = get_function_schema(func_data["func"], description="", name="")["function"]["parameters"] + return FunctionSignature(name=name, description=description, parameters=parameters) else: - parameters = get_function_schema(func_info, description="", name="")["function"]["parameters"] - return FunctionDefinition(name=func_info.__name__, description="", parameters=parameters) + parameters = get_function_schema(func_data, description="", name="")["function"]["parameters"] + return FunctionSignature(name=func_data.__name__, description="", parameters=parameters) diff --git a/src/agnext/agent_components/function_executor/_impl/in_process_function_executor.py b/src/agnext/agent_components/function_executor/_impl/in_process_function_executor.py index 7fddbf601..03bf69b34 100644 --- 
a/src/agnext/agent_components/function_executor/_impl/in_process_function_executor.py +++ b/src/agnext/agent_components/function_executor/_impl/in_process_function_executor.py @@ -1,40 +1,53 @@ import asyncio import functools from collections.abc import Sequence -from typing import Any, Callable, Union +from typing import Any, Callable, Dict, Union -from .._base import FunctionExecutor, FunctionInfo +from .._base import Function, FunctionExecutor class InProcessFunctionExecutor(FunctionExecutor): def __init__( self, - functions: Sequence[Union[Callable[..., Any], FunctionInfo]] = [], + functions: Sequence[Union[Callable[..., Any], Function]] = [], ) -> None: - def _name(func: Union[Callable[..., Any], FunctionInfo]) -> str: + def _name(func: Union[Callable[..., Any], Function]) -> str: if isinstance(func, dict): return func.get("name", func["func"].__name__) else: return func.__name__ - def _func(func: Union[Callable[..., Any], FunctionInfo]) -> Any: + def _func(func: Union[Callable[..., Any], Function]) -> Any: if isinstance(func, dict): return func.get("func") else: return func - self._functions = dict([(_name(x), _func(x)) for x in functions]) + def _description(func: Union[Callable[..., Any], Function]) -> str: + if isinstance(func, dict): + return func.get("description", "") + else: + return "" + + self._functions: Dict[str, Function] = dict() + for func in functions: + name = _name(func) + self._functions[name] = Function( + func=_func(func), + name=name, + description=_description(func), + ) async def execute_function(self, function_name: str, arguments: dict[str, Any]) -> str: if function_name in self._functions: - function = self._functions[function_name] + function = self._functions[function_name]["func"] if asyncio.iscoroutinefunction(function): return str(function(**arguments)) else: return await asyncio.get_event_loop().run_in_executor(None, functools.partial(function, **arguments)) - raise ValueError(f"Function {function_name} not found") + raise 
ValueError(f"Function {function_name} not found.") @property - def functions(self) -> Sequence[str]: - return list(self._functions.keys()) + def functions(self) -> Sequence[Function]: + return list(self._functions.values()) diff --git a/src/agnext/agent_components/function_utils.py b/src/agnext/agent_components/function_utils.py index a6e1c7dc0..e76e2355e 100644 --- a/src/agnext/agent_components/function_utils.py +++ b/src/agnext/agent_components/function_utils.py @@ -1,9 +1,21 @@ -# File based from: https://github.com/microsoft/autogen/blob/main/autogen/function_utils.py +# File based from: https://github.com/microsoft/autogen/blob/47f905267245e143562abfb41fcba503a9e1d56d/autogen/function_utils.py # Credit to original authors import inspect from logging import getLogger -from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, Tuple, Type, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + ForwardRef, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) from pydantic import BaseModel, Field from typing_extensions import Annotated, Literal @@ -74,7 +86,9 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any: return get_typed_annotation(annotation, globalns) -def get_param_annotations(typed_signature: inspect.Signature) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]: +def get_param_annotations( + typed_signature: inspect.Signature, +) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]: """Get the type annotations of the parameters of a function Args: diff --git a/src/agnext/agent_components/model_client/_model_client.py b/src/agnext/agent_components/model_client/_model_client.py index 9f46b4f82..209292b6d 100644 --- a/src/agnext/agent_components/model_client/_model_client.py +++ b/src/agnext/agent_components/model_client/_model_client.py @@ -11,7 +11,7 @@ from typing_extensions import ( Union, ) -from ..types import CreateResult, FunctionDefinition, LLMMessage, RequestUsage +from ..types 
import CreateResult, FunctionSignature, LLMMessage, RequestUsage class ModelCapabilities(TypedDict, total=False): @@ -26,7 +26,7 @@ class ModelClient(Protocol): async def create( self, messages: Sequence[LLMMessage], - functions: Sequence[FunctionDefinition] = [], + functions: Sequence[FunctionSignature] = [], # None means do not override the default # A value means to override the client default - often specified in the constructor json_output: Optional[bool] = None, @@ -36,7 +36,7 @@ class ModelClient(Protocol): def create_stream( self, messages: Sequence[LLMMessage], - functions: Sequence[FunctionDefinition] = [], + functions: Sequence[FunctionSignature] = [], # None means do not override the default # A value means to override the client default - often specified in the constructor json_output: Optional[bool] = None, diff --git a/src/agnext/agent_components/model_client/_openai_client.py b/src/agnext/agent_components/model_client/_openai_client.py index c12ad1ec7..cdea77359 100644 --- a/src/agnext/agent_components/model_client/_openai_client.py +++ b/src/agnext/agent_components/model_client/_openai_client.py @@ -40,8 +40,8 @@ from ..types import ( AssistantMessage, CreateResult, FunctionCall, - FunctionDefinition, FunctionExecutionResultMessage, + FunctionSignature, LLMMessage, RequestUsage, SystemMessage, @@ -250,7 +250,7 @@ class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False) def convert_functions( - functions: Sequence[FunctionDefinition], + functions: Sequence[FunctionSignature], ) -> List[ChatCompletionToolParam]: result: List[ChatCompletionToolParam] = [] for func in functions: @@ -304,7 +304,7 @@ class BaseOpenAI(ModelClient): async def create( self, messages: Sequence[LLMMessage], - functions: Sequence[FunctionDefinition] = [], + functions: Sequence[FunctionSignature] = [], json_output: Optional[bool] = None, extra_create_args: Mapping[str, Any] = {}, ) -> CreateResult: @@ -353,7 +353,10 @@ class BaseOpenAI(ModelClient): if 
result.usage is not None: logger.info( - LLMCallEvent(prompt_tokens=result.usage.prompt_tokens, completion_tokens=result.usage.completion_tokens) + LLMCallEvent( + prompt_tokens=result.usage.prompt_tokens, + completion_tokens=result.usage.completion_tokens, + ) ) usage = RequestUsage( @@ -400,7 +403,7 @@ class BaseOpenAI(ModelClient): async def create_stream( self, messages: Sequence[LLMMessage], - functions: Sequence[FunctionDefinition] = [], + functions: Sequence[FunctionSignature] = [], json_output: Optional[bool] = None, extra_create_args: Mapping[str, Any] = {}, ) -> AsyncGenerator[Union[str, CreateResult], None]: diff --git a/src/agnext/agent_components/pydantic_compat.py b/src/agnext/agent_components/pydantic_compat.py index d554d83d2..661350616 100644 --- a/src/agnext/agent_components/pydantic_compat.py +++ b/src/agnext/agent_components/pydantic_compat.py @@ -1,4 +1,4 @@ -# File based from: https://github.com/microsoft/autogen/blob/main/autogen/_pydantic.py +# File based from: https://github.com/microsoft/autogen/blob/47f905267245e143562abfb41fcba503a9e1d56d/autogen/_pydantic.py # Credit to original authors @@ -14,7 +14,9 @@ PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.") def evaluate_forwardref( - value: Any, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None + value: Any, + globalns: dict[str, Any] | None = None, + localns: dict[str, Any] | None = None, ) -> Any: if PYDANTIC_V1: from pydantic.typing import evaluate_forwardref as evaluate_forwardref_internal diff --git a/src/agnext/agent_components/types.py b/src/agnext/agent_components/types.py index 1a1aaf4b8..22903d6ee 100644 --- a/src/agnext/agent_components/types.py +++ b/src/agnext/agent_components/types.py @@ -18,7 +18,7 @@ class FunctionCall: @dataclass -class FunctionDefinition: +class FunctionSignature: name: str parameters: Dict[str, Any] description: str diff --git a/src/agnext/chat/agents/chat_completion_agent.py b/src/agnext/chat/agents/chat_completion_agent.py 
index 9fc4de11e..3eb9f4f90 100644 --- a/src/agnext/chat/agents/chat_completion_agent.py +++ b/src/agnext/chat/agents/chat_completion_agent.py @@ -1,10 +1,26 @@ -from typing import Any, Callable, Dict, List, Mapping +import asyncio +import json +from typing import Any, Coroutine, Dict, List, Mapping, Tuple +from agnext.agent_components.function_executor import FunctionExecutor from agnext.agent_components.model_client import ModelClient from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler -from agnext.agent_components.types import SystemMessage +from agnext.agent_components.types import ( + FunctionCall, + FunctionSignature, + SystemMessage, +) from agnext.chat.agents.base import BaseChatAgent -from agnext.chat.types import Message, Reset, RespondNow, ResponseFormat, TextMessage +from agnext.chat.types import ( + FunctionCallMessage, + FunctionExecutionResult, + FunctionExecutionResultMessage, + Message, + Reset, + RespondNow, + ResponseFormat, + TextMessage, +) from agnext.chat.utils import convert_messages_to_llm_messages from agnext.core import AgentRuntime, CancellationToken @@ -17,13 +33,13 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent): runtime: AgentRuntime, system_messages: List[SystemMessage], model_client: ModelClient, - tools: Dict[str, Callable[..., str]] | None = None, + function_executor: FunctionExecutor | None = None, ) -> None: super().__init__(name, description, runtime) self._system_messages = system_messages self._client = model_client - self._tools = tools or {} self._chat_messages: List[Message] = [] + self._function_executor = function_executor @message_handler(TextMessage) async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: @@ -36,20 +52,112 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent): self._chat_messages = [] @message_handler(RespondNow) - async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) 
-> TextMessage: - if message.response_format == ResponseFormat.json_object: - response_format = {"type": "json_object"} - else: - response_format = {"type": "text"} + async def on_respond_now( + self, message: RespondNow, cancellation_token: CancellationToken + ) -> TextMessage | FunctionCallMessage: + # Get function signatures. + function_signatures: List[FunctionSignature] = ( + [] if self._function_executor is None else list(self._function_executor.function_signatures) + ) + + # Get a response from the model. response = await self._client.create( self._system_messages + convert_messages_to_llm_messages(self._chat_messages, self.name), - extra_create_args={"response_format": response_format}, + functions=function_signatures, + json_output=message.response_format == ResponseFormat.json_object, ) + + # If the agent has function executor, and the response is a list of + # tool calls, iterate with itself until we get a response that is not a + # list of tool calls. + while ( + self._function_executor is not None + and isinstance(response.content, list) + and all(isinstance(x, FunctionCall) for x in response.content) + ): + # Send a function call message to itself. + response = await self._send_message( + message=FunctionCallMessage(content=response.content, source=self.name), + recipient=self, + cancellation_token=cancellation_token, + ) + # Make an assistant message from the response. + response = await self._client.create( + self._system_messages + convert_messages_to_llm_messages(self._chat_messages, self.name), + functions=function_signatures, + json_output=message.response_format == ResponseFormat.json_object, + ) + + final_response: Message if isinstance(response.content, str): - return TextMessage(content=response.content, source=self.name) + # If the response is a string, return a text message. 
+ final_response = TextMessage(content=response.content, source=self.name) + elif isinstance(response.content, list) and all(isinstance(x, FunctionCall) for x in response.content): + # If the response is a list of function calls, return a function call message. + final_response = FunctionCallMessage(content=response.content, source=self.name) else: raise ValueError(f"Unexpected response: {response.content}") + # Add the response to the chat messages. + self._chat_messages.append(final_response) + + # Return the response. + return final_response + + @message_handler(FunctionCallMessage) + async def on_tool_call_message( + self, message: FunctionCallMessage, cancellation_token: CancellationToken + ) -> FunctionExecutionResultMessage: + if self._function_executor is None: + raise ValueError("Function executor is not set.") + + # Add a tool call message. + self._chat_messages.append(message) + + # Execute the tool calls. + results: List[FunctionExecutionResult] = [] + execution_futures: List[Coroutine[Any, Any, Tuple[str, str]]] = [] + for function_call in message.content: + # Parse the arguments. + try: + arguments = json.loads(function_call.arguments) + except json.JSONDecodeError: + results.append( + FunctionExecutionResult( + content=f"Error: Could not parse arguments for function {function_call.name}.", + call_id=function_call.id, + ) + ) + continue + # Execute the function. + future = self.execute_function(function_call.name, arguments, function_call.id) + # Append the async result. + execution_futures.append(future) + if execution_futures: + # Wait for all async results. + execution_results = await asyncio.gather(*execution_futures) + # Add the results. + for execution_result, call_id in execution_results: + results.append(FunctionExecutionResult(content=execution_result, call_id=call_id)) + + # Create a tool call result message. + tool_call_result_msg = FunctionExecutionResultMessage(content=results, source=self.name) + + # Add tool call result message. 
+ self._chat_messages.append(tool_call_result_msg) + + # Return the results. + return tool_call_result_msg + + async def execute_function(self, name: str, args: Dict[str, Any], call_id: str) -> Tuple[str, str]: + if self._function_executor is None: + raise ValueError("Function executor is not set.") + try: + result = await self._function_executor.execute_function(name, args) + except Exception as e: + result = f"Error: {str(e)}" + return (result, call_id) + def save_state(self) -> Mapping[str, Any]: return { "description": self.description, diff --git a/src/agnext/chat/patterns/orchestrator_chat.py b/src/agnext/chat/patterns/orchestrator_chat.py index 060e334c6..6f920a16b 100644 --- a/src/agnext/chat/patterns/orchestrator_chat.py +++ b/src/agnext/chat/patterns/orchestrator_chat.py @@ -6,6 +6,8 @@ from ...core import AgentRuntime, CancellationToken from ..agents.base import BaseChatAgent from ..types import Reset, RespondNow, ResponseFormat, TextMessage +__all__ = ["OrchestratorChat"] + class OrchestratorChat(BaseChatAgent, TypeRoutedAgent): def __init__( @@ -255,15 +257,65 @@ Please output an answer in pure JSON format according to the following schema. T }} }} """.strip() - # Send a message to the orchestrator. - self._send_message(TextMessage(content=step_prompt, source=sender), self._orchestrator) - # Request a response. - step_response = await self._send_message( - RespondNow(response_format=ResponseFormat.json_object), self._orchestrator - ) - # TODO: handle invalid JSON. - # TODO: use typed dictionary. - return json.loads(step_response.content) + request = step_prompt + while True: + # Send a message to the orchestrator. + self._send_message(TextMessage(content=request, source=sender), self._orchestrator) + # Request a response. + step_response = await self._send_message( + RespondNow(response_format=ResponseFormat.json_object), + self._orchestrator, + ) + # TODO: use typed dictionary. 
+ try: + result = json.loads(str(step_response.content)) + except json.JSONDecodeError as e: + request = f"Invalid JSON: {str(e)}" + continue + if "is_request_satisfied" not in result: + request = "Missing key: is_request_satisfied" + continue + elif ( + not isinstance(result["is_request_satisfied"], dict) + or "answer" not in result["is_request_satisfied"] + or "reason" not in result["is_request_satisfied"] + ): + request = "Invalid value for key: is_request_satisfied, expected 'answer' and 'reason'" + continue + if "is_progress_being_made" not in result: + request = "Missing key: is_progress_being_made" + continue + elif ( + not isinstance(result["is_progress_being_made"], dict) + or "answer" not in result["is_progress_being_made"] + or "reason" not in result["is_progress_being_made"] + ): + request = "Invalid value for key: is_progress_being_made, expected 'answer' and 'reason'" + continue + if "next_speaker" not in result: + request = "Missing key: next_speaker" + continue + elif ( + not isinstance(result["next_speaker"], dict) + or "answer" not in result["next_speaker"] + or "reason" not in result["next_speaker"] + ): + request = "Invalid value for key: next_speaker, expected 'answer' and 'reason'" + continue + elif result["next_speaker"]["answer"] not in names: + request = f"Invalid value for key: next_speaker, expected 'answer' in {names}" + continue + if "instruction_or_question" not in result: + request = "Missing key: instruction_or_question" + continue + elif ( + not isinstance(result["instruction_or_question"], dict) + or "answer" not in result["instruction_or_question"] + or "reason" not in result["instruction_or_question"] + ): + request = "Invalid value for key: instruction_or_question, expected 'answer' and 'reason'" + continue + return result async def _rewrite_facts(self, facts: str, sender: str) -> str: new_facts_prompt = f"""It's clear we aren't making as much progress as we would like, but we may have learned something new. 
Please rewrite the following fact sheet, updating it to include anything new we have learned. This is also a good time to update educated guesses (please add or update at least one educated guess or hunch, and explain your reasoning). @@ -293,15 +345,35 @@ Please output an answer in pure JSON format according to the following schema. T }} }} """.strip() - # Send a message to the orchestrator. - self._send_message(TextMessage(content=educated_guess_promt, source=sender), self._orchestrator) - # Request a response. - educated_guess_response = await self._send_message( - RespondNow(response_format=ResponseFormat.json_object), self._orchestrator - ) - # TODO: handle invalid JSON. - # TODO: use typed dictionary. - return json.loads(str(educated_guess_response.content)) + request = educated_guess_promt + while True: + # Send a message to the orchestrator. + self._send_message( + TextMessage(content=request, source=sender), + self._orchestrator, + ) + # Request a response. + response = await self._send_message( + RespondNow(response_format=ResponseFormat.json_object), + self._orchestrator, + ) + try: + result = json.loads(str(response.content)) + except json.JSONDecodeError as e: + request = f"Invalid JSON: {str(e)}" + continue + # TODO: use typed dictionary. + if "has_educated_guesses" not in result: + request = "Missing key: has_educated_guesses" + continue + if ( + not isinstance(result["has_educated_guesses"], dict) + or "answer" not in result["has_educated_guesses"] + or "reason" not in result["has_educated_guesses"] + ): + request = "Invalid value for key: has_educated_guesses, expected 'answer' and 'reason'" + continue + return result async def _rewrite_plan(self, team: str, sender: str) -> str: new_plan_prompt = f"""Please come up with a new plan expressed in bullet points. Keep in mind the following team composition, and do not involve any other outside people in the plan -- we cannot contact anyone else. 
def convert_tool_call_response_message(
    message: FunctionExecutionResultMessage,
    handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> Optional[FunctionExecutionResultMessageType]:
    """Convert a chat-level function execution result message to its LLM form.

    Each result's ``content`` and ``call_id`` are copied into a new
    ``FunctionExecutionResultType``; the message ``source`` is not carried
    over. ``handle_unrepresentable`` is accepted for signature parity with
    the other converters but is not consulted here.

    Returns:
        The converted message, or ``None`` when ``message`` is not a
        ``FunctionExecutionResultMessage``.
    """
    if not isinstance(message, FunctionExecutionResultMessage):
        return None

    converted_results = []
    for item in message.content:
        converted_results.append(FunctionExecutionResultType(content=item.content, call_id=item.call_id))
    return FunctionExecutionResultMessageType(content=converted_results)
"error" + messages: List[Message], + self_name: str, + handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error", ) -> List[LLMMessage]: result: List[LLMMessage] = [] for message in messages: @@ -63,6 +93,10 @@ def convert_messages_to_llm_messages( converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable) if converted_message_2 is not None: result.append(converted_message_2) + case FunctionExecutionResultMessage(_, source=source) if source == self_name: + converted_message_3 = convert_tool_call_response_message(message, handle_unrepresentable) + if converted_message_3 is not None: + result.append(converted_message_3) case _: raise AssertionError("unreachable") diff --git a/tests/execution/test_user_defined_functions.py b/tests/execution/test_user_defined_functions.py index 382ecca2c..6db7ebc4f 100644 --- a/tests/execution/test_user_defined_functions.py +++ b/tests/execution/test_user_defined_functions.py @@ -3,12 +3,17 @@ import tempfile -import pytest - -from agnext.agent_components.code_executor import LocalCommandLineCodeExecutor, CodeBlock -from agnext.agent_components.func_with_reqs import FunctionWithRequirements, with_requirements - import polars +import pytest +from agnext.agent_components.code_executor import ( + CodeBlock, + LocalCommandLineCodeExecutor, +) +from agnext.agent_components.func_with_reqs import ( + FunctionWithRequirements, + with_requirements, +) + def add_two_numbers(a: int, b: int) -> int: """Add two numbers together.""" @@ -46,7 +51,9 @@ def function_missing_reqs() -> "polars.DataFrame": def test_can_load_function_with_reqs() -> None: with tempfile.TemporaryDirectory() as temp_dir: - executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[load_data]) + executor = LocalCommandLineCodeExecutor( + work_dir=temp_dir, functions=[load_data] + ) code = f"""from {executor.functions_module} import load_data import polars @@ -65,7 +72,9 @@ print(data['name'][0])""" def 
test_can_load_function() -> None: with tempfile.TemporaryDirectory() as temp_dir: - executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[add_two_numbers]) + executor = LocalCommandLineCodeExecutor( + work_dir=temp_dir, functions=[add_two_numbers] + ) code = f"""from {executor.functions_module} import add_two_numbers print(add_two_numbers(1, 2))""" @@ -80,7 +89,9 @@ print(add_two_numbers(1, 2))""" def test_fails_for_function_incorrect_import() -> None: with tempfile.TemporaryDirectory() as temp_dir: - executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[function_incorrect_import]) + executor = LocalCommandLineCodeExecutor( + work_dir=temp_dir, functions=[function_incorrect_import] + ) code = f"""from {executor.functions_module} import function_incorrect_import function_incorrect_import()""" @@ -94,7 +105,9 @@ function_incorrect_import()""" def test_fails_for_function_incorrect_dep() -> None: with tempfile.TemporaryDirectory() as temp_dir: - executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[function_incorrect_dep]) + executor = LocalCommandLineCodeExecutor( + work_dir=temp_dir, functions=[function_incorrect_dep] + ) code = f"""from {executor.functions_module} import function_incorrect_dep function_incorrect_dep()""" @@ -106,10 +119,11 @@ function_incorrect_dep()""" ) - def test_formatted_prompt() -> None: with tempfile.TemporaryDirectory() as temp_dir: - executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[add_two_numbers]) + executor = LocalCommandLineCodeExecutor( + work_dir=temp_dir, functions=[add_two_numbers] + ) result = executor.format_functions_for_prompt() assert ( @@ -140,7 +154,6 @@ def add_two_numbers(a: int, b: int) -> int: ) - def test_can_load_str_function_with_reqs() -> None: with tempfile.TemporaryDirectory() as temp_dir: func = FunctionWithRequirements.from_str(