add tool call to chat completion agent (#35)

* add tool call to chat completion agent

* refactor function executor; tool executor in chat completion agent

* update example

* update orchestrator chat demo

* handle function execution result message type

* format

* temp fix for examples.

* fix

* update chat completion agent
This commit is contained in:
Eric Zhu 2024-05-30 09:01:35 -07:00 committed by GitHub
parent 2dc7af87ef
commit 04c30596ed
15 changed files with 516 additions and 361 deletions

155
examples/orchestrator.py Normal file
View File

@ -0,0 +1,155 @@
import argparse
import asyncio
import json
import logging
import os
from typing import Annotated, Callable
import openai
from agnext.agent_components.function_executor._impl.in_process_function_executor import (
InProcessFunctionExecutor,
)
from agnext.agent_components.model_client import OpenAI
from agnext.agent_components.types import SystemMessage
from agnext.application_components import (
SingleThreadedAgentRuntime,
)
from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.patterns.orchestrator_chat import OrchestratorChat
from agnext.chat.types import TextMessage
from agnext.core import Agent, AgentRuntime
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
from tavily import TavilyClient
from typing_extensions import Any, override
logging.basicConfig(level=logging.WARNING)
logging.getLogger("agnext").setLevel(logging.DEBUG)
class LoggingHandler(DefaultInterventionHandler): # type: ignore
send_color = "\033[31m"
response_color = "\033[34m"
reset_color = "\033[0m"
@override
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: # type: ignore
if sender is None:
print(f"{self.send_color}Sending message to {recipient.name}:{self.reset_color} {message}")
else:
print(
f"{self.send_color}Sending message from {sender.name} to {recipient.name}:{self.reset_color} {message}"
)
return message
@override
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: # type: ignore
if recipient is None:
print(f"{self.response_color}Received response from {sender.name}:{self.reset_color} {message}")
else:
print(
f"{self.response_color}Received response from {sender.name} to {recipient.name}:{self.reset_color} {message}"
)
return message
def software_development(runtime: AgentRuntime) -> OrchestratorChat: # type: ignore
developer = ChatCompletionAgent(
name="Developer",
description="A developer that writes code.",
runtime=runtime,
system_messages=[SystemMessage("You are a Python developer.")],
model_client=OpenAI(model="gpt-4-turbo"),
)
tester_oai_assistant = openai.beta.assistants.create(
model="gpt-4-turbo",
description="A software tester that runs test cases and reports results.",
instructions="You are a software tester that runs test cases and reports results.",
)
tester_oai_thread = openai.beta.threads.create()
tester = OpenAIAssistantAgent(
name="Tester",
description="A software tester that runs test cases and reports results.",
runtime=runtime,
client=openai.AsyncClient(),
assistant_id=tester_oai_assistant.id,
thread_id=tester_oai_thread.id,
)
def search(query: Annotated[str, "The search query."]) -> Annotated[str, "The search results."]:
"""Search the web."""
client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
result = client.search(query) # type: ignore
if result:
return json.dumps(result, indent=2, ensure_ascii=False) # type: ignore
return "No results found."
function_executor = InProcessFunctionExecutor(functions=[search])
product_manager = ChatCompletionAgent(
name="ProductManager",
description="A product manager that performs research and comes up with specs.",
runtime=runtime,
system_messages=[
SystemMessage("You are a product manager good at translating customer needs into software specifications."),
SystemMessage("You can use the search tool to find information on the web."),
],
model_client=OpenAI(model="gpt-4-turbo"),
function_executor=function_executor,
)
planner = ChatCompletionAgent(
name="Planner",
description="A planner that organizes and schedules tasks.",
runtime=runtime,
system_messages=[SystemMessage("You are a planner of complex tasks.")],
model_client=OpenAI(model="gpt-4-turbo"),
)
orchestrator = ChatCompletionAgent(
name="Orchestrator",
description="An orchestrator that coordinates the team.",
runtime=runtime,
system_messages=[
SystemMessage("You are an orchestrator that coordinates the team to complete a complex task.")
],
model_client=OpenAI(model="gpt-4-turbo"),
)
return OrchestratorChat(
"OrchestratorChat",
"A software development team.",
runtime,
orchestrator=orchestrator,
planner=planner,
specialists=[developer, product_manager, tester],
)
async def run(message: str, user: str, scenario: Callable[[AgentRuntime], OrchestratorChat]) -> None: # type: ignore
runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())
chat = scenario(runtime)
response = runtime.send_message(TextMessage(content=message, source=user), chat)
while not response.done():
await runtime.process_next()
print((await response).content) # type: ignore
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run a orchestrator demo.")
choices = {"software_development": software_development}
parser.add_argument(
"--scenario",
choices=list(choices.keys()),
help="The scenario to demo.",
default="software_development",
)
parser.add_argument(
"--user",
default="Customer",
help="The user to send the message. Default is 'Customer'.",
)
parser.add_argument("--message", help="The message to send.", required=True)
args = parser.parse_args()
asyncio.run(run(args.message, args.user, choices[args.scenario]))

View File

@ -1,268 +0,0 @@
import argparse
import asyncio
import logging
import openai
from agnext.agent_components.types import SystemMessage
from agnext.application_components import (
SingleThreadedAgentRuntime,
)
from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
from agnext.chat.patterns.orchestrator_chat import OrchestratorChat
from agnext.chat.types import TextMessage
from agnext.core._agent import Agent
from agnext.agent_components.model_client import OpenAI
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
from typing_extensions import Any, override
logging.basicConfig(level=logging.WARNING)
logging.getLogger("agnext").setLevel(logging.DEBUG)
class ConcatOutput(GroupChatOutput):
def __init__(self) -> None:
self._output = ""
def on_message_received(self, message: Any) -> None:
match message:
case TextMessage(content=content):
self._output += content
case _:
...
def get_output(self) -> Any:
return self._output
def reset(self) -> None:
self._output = ""
class LoggingHandler(DefaultInterventionHandler):
send_color = "\033[31m"
response_color = "\033[34m"
reset_color = "\033[0m"
@override
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]:
if sender is None:
print(f"{self.send_color}Sending message to {recipient.name}:{self.reset_color} {message}")
else:
print(
f"{self.send_color}Sending message from {sender.name} to {recipient.name}:{self.reset_color} {message}"
)
return message
@override
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]:
if recipient is None:
print(f"{self.response_color}Received response from {sender.name}:{self.reset_color} {message}")
else:
print(
f"{self.response_color}Received response from {sender.name} to {recipient.name}:{self.reset_color} {message}"
)
return message
async def group_chat(message: str) -> None:
runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())
joe_oai_assistant = openai.beta.assistants.create(
model="gpt-3.5-turbo",
name="Joe",
instructions="You are a commedian named Joe. Make the audience laugh.",
)
joe_oai_thread = openai.beta.threads.create()
joe = OpenAIAssistantAgent(
name="Joe",
description="Joe the commedian.",
runtime=runtime,
client=openai.AsyncClient(),
assistant_id=joe_oai_assistant.id,
thread_id=joe_oai_thread.id,
)
cathy_oai_assistant = openai.beta.assistants.create(
model="gpt-3.5-turbo",
name="Cathy",
instructions="You are a poet named Cathy. Write beautiful poems.",
)
cathy_oai_thread = openai.beta.threads.create()
cathy = OpenAIAssistantAgent(
name="Cathy",
description="Cathy the poet.",
runtime=runtime,
client=openai.AsyncClient(),
assistant_id=cathy_oai_assistant.id,
thread_id=cathy_oai_thread.id,
)
chat = GroupChat(
"Host",
"A round-robin chat room.",
runtime,
[joe, cathy],
num_rounds=5,
output=ConcatOutput(),
)
response = runtime.send_message(TextMessage(content=message, source="host"), chat)
while not response.done():
await runtime.process_next()
await response
async def orchestrator_oai_assistant(message: str) -> None:
runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())
developer_oai_assistant = openai.beta.assistants.create(
model="gpt-3.5-turbo",
name="Developer",
instructions="You are a Python developer.",
)
developer_oai_thread = openai.beta.threads.create()
developer = OpenAIAssistantAgent(
name="Developer",
description="A developer that writes code.",
runtime=runtime,
client=openai.AsyncClient(),
assistant_id=developer_oai_assistant.id,
thread_id=developer_oai_thread.id,
)
product_manager_oai_assistant = openai.beta.assistants.create(
model="gpt-3.5-turbo",
name="ProductManager",
instructions="You are a product manager good at translating customer needs into software specifications.",
)
product_manager_oai_thread = openai.beta.threads.create()
product_manager = OpenAIAssistantAgent(
name="ProductManager",
description="A product manager that plans and comes up with specs.",
runtime=runtime,
client=openai.AsyncClient(),
assistant_id=product_manager_oai_assistant.id,
thread_id=product_manager_oai_thread.id,
)
planner_oai_assistant = openai.beta.assistants.create(
model="gpt-4-turbo",
name="Planner",
instructions="You are a planner of complex tasks.",
)
planner_oai_thread = openai.beta.threads.create()
planner = OpenAIAssistantAgent(
name="Planner",
description="A planner that organizes and schedules tasks.",
runtime=runtime,
client=openai.AsyncClient(),
assistant_id=planner_oai_assistant.id,
thread_id=planner_oai_thread.id,
)
orchestrator_oai_assistant = openai.beta.assistants.create(
model="gpt-4-turbo",
name="Orchestrator",
instructions="You are an orchestrator that coordinates the team to complete a complex task.",
)
orchestrator_oai_thread = openai.beta.threads.create()
orchestrator = OpenAIAssistantAgent(
name="Orchestrator",
description="An orchestrator that coordinates the team.",
runtime=runtime,
client=openai.AsyncClient(),
assistant_id=orchestrator_oai_assistant.id,
thread_id=orchestrator_oai_thread.id,
)
chat = OrchestratorChat(
"OrchestratorChat",
"A software development team.",
runtime,
orchestrator=orchestrator,
planner=planner,
specialists=[developer, product_manager],
)
response = runtime.send_message(TextMessage(content=message, source="Customer"), chat)
while not response.done():
await runtime.process_next()
print((await response).content) # type: ignore
async def orchestrator_chat_completion(message: str) -> None:
runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())
developer = ChatCompletionAgent(
name="Developer",
description="A developer that writes code.",
runtime=runtime,
system_messages=[SystemMessage("You are a Python developer.")],
model_client=OpenAI(model="gpt-3.5-turbo"),
)
product_manager = ChatCompletionAgent(
name="ProductManager",
description="A product manager that plans and comes up with specs.",
runtime=runtime,
system_messages=[
SystemMessage("You are a product manager good at translating customer needs into software specifications.")
],
model_client=OpenAI(model="gpt-3.5-turbo"),
)
planner = ChatCompletionAgent(
name="Planner",
description="A planner that organizes and schedules tasks.",
runtime=runtime,
system_messages=[SystemMessage("You are a planner of complex tasks.")],
model_client=OpenAI(model="gpt-4-turbo"),
)
orchestrator = ChatCompletionAgent(
name="Orchestrator",
description="An orchestrator that coordinates the team.",
runtime=runtime,
system_messages=[
SystemMessage("You are an orchestrator that coordinates the team to complete a complex task.")
],
model_client=OpenAI(model="gpt-4-turbo"),
)
chat = OrchestratorChat(
"OrchestratorChat",
"A software development team.",
runtime,
orchestrator=orchestrator,
planner=planner,
specialists=[developer, product_manager],
)
response = runtime.send_message(TextMessage(content=message, source="Customer"), chat)
while not response.done():
await runtime.process_next()
print((await response).content) # type: ignore
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run a pattern demo.")
choices = {
"group_chat": group_chat,
"orchestrator_oai_assistant": orchestrator_oai_assistant,
"orchestrator_chat_completion": orchestrator_chat_completion,
}
parser.add_argument(
"--pattern",
choices=list(choices.keys()),
help="The pattern to demo.",
)
parser.add_argument("--message", help="The message to send.")
args = parser.parse_args()
asyncio.run(choices[args.pattern](args.message))

View File

@ -72,7 +72,7 @@ disallow_untyped_decorators = true
disallow_any_unimported = true
[tool.pyright]
include = ["src", "examples", "tests"]
include = ["src", "tests"]
typeCheckingMode = "strict"
reportUnnecessaryIsInstance = false
reportMissingTypeStubs = false

View File

@ -1,4 +1,9 @@
from ._base import FunctionExecutor, FunctionInfo, into_function_definition
from ._base import Function, FunctionExecutor, into_function_signature
from ._impl.in_process_function_executor import InProcessFunctionExecutor
__all__ = ["FunctionExecutor", "FunctionInfo", "into_function_definition", "InProcessFunctionExecutor"]
__all__ = [
"FunctionExecutor",
"Function",
"into_function_signature",
"InProcessFunctionExecutor",
]

View File

@ -4,7 +4,13 @@ from typing import Any, Callable, Dict, Protocol, TypedDict, Union, runtime_chec
from typing_extensions import NotRequired, Required
from ..function_utils import get_function_schema
from ..types import FunctionDefinition
from ..types import FunctionSignature
class Function(TypedDict):
func: Required[Callable[..., Any]]
name: NotRequired[str]
description: NotRequired[str]
@runtime_checkable
@ -12,25 +18,23 @@ class FunctionExecutor(Protocol):
async def execute_function(self, function_name: str, arguments: Dict[str, Any]) -> str: ...
@property
def functions(self) -> Sequence[str]: ...
def functions(self) -> Sequence[Function]: ...
@property
def function_signatures(self) -> Sequence[FunctionSignature]:
return [into_function_signature(func) for func in self.functions]
class FunctionInfo(TypedDict):
func: Required[Callable[..., Any]]
name: NotRequired[str]
description: NotRequired[str]
def into_function_definition(
func_info: Union[FunctionInfo, FunctionDefinition, Callable[..., Any]],
) -> FunctionDefinition:
if isinstance(func_info, FunctionDefinition):
return func_info
elif isinstance(func_info, dict):
name = func_info.get("name", func_info["func"].__name__)
description = func_info.get("description", "")
parameters = get_function_schema(func_info["func"], description="", name="")["function"]["parameters"]
return FunctionDefinition(name=name, description=description, parameters=parameters)
def into_function_signature(
func_data: Union[Function, FunctionSignature, Callable[..., Any]],
) -> FunctionSignature:
if isinstance(func_data, FunctionSignature):
return func_data
elif isinstance(func_data, dict):
name = func_data.get("name", func_data["func"].__name__)
description = func_data.get("description", "")
parameters = get_function_schema(func_data["func"], description="", name="")["function"]["parameters"]
return FunctionSignature(name=name, description=description, parameters=parameters)
else:
parameters = get_function_schema(func_info, description="", name="")["function"]["parameters"]
return FunctionDefinition(name=func_info.__name__, description="", parameters=parameters)
parameters = get_function_schema(func_data, description="", name="")["function"]["parameters"]
return FunctionSignature(name=func_data.__name__, description="", parameters=parameters)

View File

@ -1,40 +1,53 @@
import asyncio
import functools
from collections.abc import Sequence
from typing import Any, Callable, Union
from typing import Any, Callable, Dict, Union
from .._base import FunctionExecutor, FunctionInfo
from .._base import Function, FunctionExecutor
class InProcessFunctionExecutor(FunctionExecutor):
def __init__(
self,
functions: Sequence[Union[Callable[..., Any], FunctionInfo]] = [],
functions: Sequence[Union[Callable[..., Any], Function]] = [],
) -> None:
def _name(func: Union[Callable[..., Any], FunctionInfo]) -> str:
def _name(func: Union[Callable[..., Any], Function]) -> str:
if isinstance(func, dict):
return func.get("name", func["func"].__name__)
else:
return func.__name__
def _func(func: Union[Callable[..., Any], FunctionInfo]) -> Any:
def _func(func: Union[Callable[..., Any], Function]) -> Any:
if isinstance(func, dict):
return func.get("func")
else:
return func
self._functions = dict([(_name(x), _func(x)) for x in functions])
def _description(func: Union[Callable[..., Any], Function]) -> str:
if isinstance(func, dict):
return func.get("description", "")
else:
return ""
self._functions: Dict[str, Function] = dict()
for func in functions:
name = _name(func)
self._functions[name] = Function(
func=_func(func),
name=name,
description=_description(func),
)
async def execute_function(self, function_name: str, arguments: dict[str, Any]) -> str:
if function_name in self._functions:
function = self._functions[function_name]
function = self._functions[function_name]["func"]
if asyncio.iscoroutinefunction(function):
return str(function(**arguments))
else:
return await asyncio.get_event_loop().run_in_executor(None, functools.partial(function, **arguments))
raise ValueError(f"Function {function_name} not found")
raise ValueError(f"Function {function_name} not found.")
@property
def functions(self) -> Sequence[str]:
return list(self._functions.keys())
def functions(self) -> Sequence[Function]:
return list(self._functions.values())

View File

@ -1,9 +1,21 @@
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/function_utils.py
# File based from: https://github.com/microsoft/autogen/blob/47f905267245e143562abfb41fcba503a9e1d56d/autogen/function_utils.py
# Credit to original authors
import inspect
from logging import getLogger
from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, Tuple, Type, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
ForwardRef,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
)
from pydantic import BaseModel, Field
from typing_extensions import Annotated, Literal
@ -74,7 +86,9 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
return get_typed_annotation(annotation, globalns)
def get_param_annotations(typed_signature: inspect.Signature) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]:
def get_param_annotations(
typed_signature: inspect.Signature,
) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]:
"""Get the type annotations of the parameters of a function
Args:

View File

@ -11,7 +11,7 @@ from typing_extensions import (
Union,
)
from ..types import CreateResult, FunctionDefinition, LLMMessage, RequestUsage
from ..types import CreateResult, FunctionSignature, LLMMessage, RequestUsage
class ModelCapabilities(TypedDict, total=False):
@ -26,7 +26,7 @@ class ModelClient(Protocol):
async def create(
self,
messages: Sequence[LLMMessage],
functions: Sequence[FunctionDefinition] = [],
functions: Sequence[FunctionSignature] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool] = None,
@ -36,7 +36,7 @@ class ModelClient(Protocol):
def create_stream(
self,
messages: Sequence[LLMMessage],
functions: Sequence[FunctionDefinition] = [],
functions: Sequence[FunctionSignature] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool] = None,

View File

@ -40,8 +40,8 @@ from ..types import (
AssistantMessage,
CreateResult,
FunctionCall,
FunctionDefinition,
FunctionExecutionResultMessage,
FunctionSignature,
LLMMessage,
RequestUsage,
SystemMessage,
@ -250,7 +250,7 @@ class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False)
def convert_functions(
functions: Sequence[FunctionDefinition],
functions: Sequence[FunctionSignature],
) -> List[ChatCompletionToolParam]:
result: List[ChatCompletionToolParam] = []
for func in functions:
@ -304,7 +304,7 @@ class BaseOpenAI(ModelClient):
async def create(
self,
messages: Sequence[LLMMessage],
functions: Sequence[FunctionDefinition] = [],
functions: Sequence[FunctionSignature] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
) -> CreateResult:
@ -353,7 +353,10 @@ class BaseOpenAI(ModelClient):
if result.usage is not None:
logger.info(
LLMCallEvent(prompt_tokens=result.usage.prompt_tokens, completion_tokens=result.usage.completion_tokens)
LLMCallEvent(
prompt_tokens=result.usage.prompt_tokens,
completion_tokens=result.usage.completion_tokens,
)
)
usage = RequestUsage(
@ -400,7 +403,7 @@ class BaseOpenAI(ModelClient):
async def create_stream(
self,
messages: Sequence[LLMMessage],
functions: Sequence[FunctionDefinition] = [],
functions: Sequence[FunctionSignature] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
) -> AsyncGenerator[Union[str, CreateResult], None]:

View File

@ -1,4 +1,4 @@
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/_pydantic.py
# File based from: https://github.com/microsoft/autogen/blob/47f905267245e143562abfb41fcba503a9e1d56d/autogen/_pydantic.py
# Credit to original authors
@ -14,7 +14,9 @@ PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")
def evaluate_forwardref(
value: Any, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None
value: Any,
globalns: dict[str, Any] | None = None,
localns: dict[str, Any] | None = None,
) -> Any:
if PYDANTIC_V1:
from pydantic.typing import evaluate_forwardref as evaluate_forwardref_internal

View File

@ -18,7 +18,7 @@ class FunctionCall:
@dataclass
class FunctionDefinition:
class FunctionSignature:
name: str
parameters: Dict[str, Any]
description: str

View File

@ -1,10 +1,26 @@
from typing import Any, Callable, Dict, List, Mapping
import asyncio
import json
from typing import Any, Coroutine, Dict, List, Mapping, Tuple
from agnext.agent_components.function_executor import FunctionExecutor
from agnext.agent_components.model_client import ModelClient
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.agent_components.types import SystemMessage
from agnext.agent_components.types import (
FunctionCall,
FunctionSignature,
SystemMessage,
)
from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.types import Message, Reset, RespondNow, ResponseFormat, TextMessage
from agnext.chat.types import (
FunctionCallMessage,
FunctionExecutionResult,
FunctionExecutionResultMessage,
Message,
Reset,
RespondNow,
ResponseFormat,
TextMessage,
)
from agnext.chat.utils import convert_messages_to_llm_messages
from agnext.core import AgentRuntime, CancellationToken
@ -17,13 +33,13 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
runtime: AgentRuntime,
system_messages: List[SystemMessage],
model_client: ModelClient,
tools: Dict[str, Callable[..., str]] | None = None,
function_executor: FunctionExecutor | None = None,
) -> None:
super().__init__(name, description, runtime)
self._system_messages = system_messages
self._client = model_client
self._tools = tools or {}
self._chat_messages: List[Message] = []
self._function_executor = function_executor
@message_handler(TextMessage)
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
@ -36,20 +52,112 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
self._chat_messages = []
@message_handler(RespondNow)
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
if message.response_format == ResponseFormat.json_object:
response_format = {"type": "json_object"}
else:
response_format = {"type": "text"}
async def on_respond_now(
self, message: RespondNow, cancellation_token: CancellationToken
) -> TextMessage | FunctionCallMessage:
# Get function signatures.
function_signatures: List[FunctionSignature] = (
[] if self._function_executor is None else list(self._function_executor.function_signatures)
)
# Get a response from the model.
response = await self._client.create(
self._system_messages + convert_messages_to_llm_messages(self._chat_messages, self.name),
extra_create_args={"response_format": response_format},
functions=function_signatures,
json_output=message.response_format == ResponseFormat.json_object,
)
# If the agent has function executor, and the response is a list of
# tool calls, iterate with itself until we get a response that is not a
# list of tool calls.
while (
self._function_executor is not None
and isinstance(response.content, list)
and all(isinstance(x, FunctionCall) for x in response.content)
):
# Send a function call message to itself.
response = await self._send_message(
message=FunctionCallMessage(content=response.content, source=self.name),
recipient=self,
cancellation_token=cancellation_token,
)
# Make an assistant message from the response.
response = await self._client.create(
self._system_messages + convert_messages_to_llm_messages(self._chat_messages, self.name),
functions=function_signatures,
json_output=message.response_format == ResponseFormat.json_object,
)
final_response: Message
if isinstance(response.content, str):
return TextMessage(content=response.content, source=self.name)
# If the response is a string, return a text message.
final_response = TextMessage(content=response.content, source=self.name)
elif isinstance(response.content, list) and all(isinstance(x, FunctionCall) for x in response.content):
# If the response is a list of function calls, return a function call message.
final_response = FunctionCallMessage(content=response.content, source=self.name)
else:
raise ValueError(f"Unexpected response: {response.content}")
# Add the response to the chat messages.
self._chat_messages.append(final_response)
# Return the response.
return final_response
@message_handler(FunctionCallMessage)
async def on_tool_call_message(
self, message: FunctionCallMessage, cancellation_token: CancellationToken
) -> FunctionExecutionResultMessage:
if self._function_executor is None:
raise ValueError("Function executor is not set.")
# Add a tool call message.
self._chat_messages.append(message)
# Execute the tool calls.
results: List[FunctionExecutionResult] = []
execution_futures: List[Coroutine[Any, Any, Tuple[str, str]]] = []
for function_call in message.content:
# Parse the arguments.
try:
arguments = json.loads(function_call.arguments)
except json.JSONDecodeError:
results.append(
FunctionExecutionResult(
content=f"Error: Could not parse arguments for function {function_call.name}.",
call_id=function_call.id,
)
)
continue
# Execute the function.
future = self.execute_function(function_call.name, arguments, function_call.id)
# Append the async result.
execution_futures.append(future)
if execution_futures:
# Wait for all async results.
execution_results = await asyncio.gather(*execution_futures)
# Add the results.
for execution_result, call_id in execution_results:
results.append(FunctionExecutionResult(content=execution_result, call_id=call_id))
# Create a tool call result message.
tool_call_result_msg = FunctionExecutionResultMessage(content=results, source=self.name)
# Add tool call result message.
self._chat_messages.append(tool_call_result_msg)
# Return the results.
return tool_call_result_msg
async def execute_function(self, name: str, args: Dict[str, Any], call_id: str) -> Tuple[str, str]:
if self._function_executor is None:
raise ValueError("Function executor is not set.")
try:
result = await self._function_executor.execute_function(name, args)
except Exception as e:
result = f"Error: {str(e)}"
return (result, call_id)
def save_state(self) -> Mapping[str, Any]:
return {
"description": self.description,

View File

@ -6,6 +6,8 @@ from ...core import AgentRuntime, CancellationToken
from ..agents.base import BaseChatAgent
from ..types import Reset, RespondNow, ResponseFormat, TextMessage
__all__ = ["OrchestratorChat"]
class OrchestratorChat(BaseChatAgent, TypeRoutedAgent):
def __init__(
@ -255,15 +257,65 @@ Please output an answer in pure JSON format according to the following schema. T
}}
}}
""".strip()
# Send a message to the orchestrator.
self._send_message(TextMessage(content=step_prompt, source=sender), self._orchestrator)
# Request a response.
step_response = await self._send_message(
RespondNow(response_format=ResponseFormat.json_object), self._orchestrator
)
# TODO: handle invalid JSON.
# TODO: use typed dictionary.
return json.loads(step_response.content)
request = step_prompt
while True:
# Send a message to the orchestrator.
self._send_message(TextMessage(content=request, source=sender), self._orchestrator)
# Request a response.
step_response = await self._send_message(
RespondNow(response_format=ResponseFormat.json_object),
self._orchestrator,
)
# TODO: use typed dictionary.
try:
result = json.loads(str(step_response.content))
except json.JSONDecodeError as e:
request = f"Invalid JSON: {str(e)}"
continue
if "is_request_satisfied" not in result:
request = "Missing key: is_request_satisfied"
continue
elif (
not isinstance(result["is_request_satisfied"], dict)
or "answer" not in result["is_request_satisfied"]
or "reason" not in result["is_request_satisfied"]
):
request = "Invalid value for key: is_request_satisfied, expected 'answer' and 'reason'"
continue
if "is_progress_being_made" not in result:
request = "Missing key: is_progress_being_made"
continue
elif (
not isinstance(result["is_progress_being_made"], dict)
or "answer" not in result["is_progress_being_made"]
or "reason" not in result["is_progress_being_made"]
):
request = "Invalid value for key: is_progress_being_made, expected 'answer' and 'reason'"
continue
if "next_speaker" not in result:
request = "Missing key: next_speaker"
continue
elif (
not isinstance(result["next_speaker"], dict)
or "answer" not in result["next_speaker"]
or "reason" not in result["next_speaker"]
):
request = "Invalid value for key: next_speaker, expected 'answer' and 'reason'"
continue
elif result["next_speaker"]["answer"] not in names:
request = f"Invalid value for key: next_speaker, expected 'answer' in {names}"
continue
if "instruction_or_question" not in result:
request = "Missing key: instruction_or_question"
continue
elif (
not isinstance(result["instruction_or_question"], dict)
or "answer" not in result["instruction_or_question"]
or "reason" not in result["instruction_or_question"]
):
request = "Invalid value for key: instruction_or_question, expected 'answer' and 'reason'"
continue
return result
async def _rewrite_facts(self, facts: str, sender: str) -> str:
new_facts_prompt = f"""It's clear we aren't making as much progress as we would like, but we may have learned something new. Please rewrite the following fact sheet, updating it to include anything new we have learned. This is also a good time to update educated guesses (please add or update at least one educated guess or hunch, and explain your reasoning).
@ -293,15 +345,35 @@ Please output an answer in pure JSON format according to the following schema. T
}}
}}
""".strip()
# Send a message to the orchestrator.
self._send_message(TextMessage(content=educated_guess_promt, source=sender), self._orchestrator)
# Request a response.
educated_guess_response = await self._send_message(
RespondNow(response_format=ResponseFormat.json_object), self._orchestrator
)
# TODO: handle invalid JSON.
# TODO: use typed dictionary.
return json.loads(str(educated_guess_response.content))
request = educated_guess_promt
while True:
# Send a message to the orchestrator.
self._send_message(
TextMessage(content=request, source=sender),
self._orchestrator,
)
# Request a response.
response = await self._send_message(
RespondNow(response_format=ResponseFormat.json_object),
self._orchestrator,
)
try:
result = json.loads(str(response.content))
except json.JSONDecodeError as e:
request = f"Invalid JSON: {str(e)}"
continue
# TODO: use typed dictionary.
if "has_educated_guesses" not in result:
request = "Missing key: has_educated_guesses"
continue
if (
not isinstance(result["has_educated_guesses"], dict)
or "answer" not in result["has_educated_guesses"]
or "reason" not in result["has_educated_guesses"]
):
request = "Invalid value for key: has_educated_guesses, expected 'answer' and 'reason'"
continue
return result
async def _rewrite_plan(self, team: str, sender: str) -> str:
new_plan_prompt = f"""Please come up with a new plan expressed in bullet points. Keep in mind the following team composition, and do not involve any other outside people in the plan -- we cannot contact anyone else.

View File

@ -2,8 +2,24 @@ from typing import List, Optional, Union
from typing_extensions import Literal
from agnext.agent_components.types import AssistantMessage, LLMMessage, UserMessage
from agnext.chat.types import FunctionCallMessage, Message, MultiModalMessage, TextMessage
from agnext.agent_components.types import (
AssistantMessage,
LLMMessage,
UserMessage,
)
from agnext.agent_components.types import (
FunctionExecutionResult as FunctionExecutionResultType,
)
from agnext.agent_components.types import (
FunctionExecutionResultMessage as FunctionExecutionResultMessageType,
)
from agnext.chat.types import (
FunctionCallMessage,
FunctionExecutionResultMessage,
Message,
MultiModalMessage,
TextMessage,
)
def convert_content_message_to_assistant_message(
@ -20,7 +36,8 @@ def convert_content_message_to_assistant_message(
return None
elif handle_unrepresentable == "try_slice":
return AssistantMessage(
content="".join([x for x in message.content if isinstance(x, str)]), source=message.source
content="".join([x for x in message.content if isinstance(x, str)]),
source=message.source,
)
@ -41,8 +58,21 @@ def convert_content_message_to_user_message(
raise NotImplementedError("Sliced function calls not yet implemented")
def convert_tool_call_response_message(
message: FunctionExecutionResultMessage,
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> Optional[FunctionExecutionResultMessageType]:
match message:
case FunctionExecutionResultMessage():
return FunctionExecutionResultMessageType(
content=[FunctionExecutionResultType(content=x.content, call_id=x.call_id) for x in message.content]
)
def convert_messages_to_llm_messages(
messages: List[Message], self_name: str, handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error"
messages: List[Message],
self_name: str,
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> List[LLMMessage]:
result: List[LLMMessage] = []
for message in messages:
@ -63,6 +93,10 @@ def convert_messages_to_llm_messages(
converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
if converted_message_2 is not None:
result.append(converted_message_2)
case FunctionExecutionResultMessage(_, source=source) if source == self_name:
converted_message_3 = convert_tool_call_response_message(message, handle_unrepresentable)
if converted_message_3 is not None:
result.append(converted_message_3)
case _:
raise AssertionError("unreachable")

View File

@ -3,12 +3,17 @@
import tempfile
import pytest
from agnext.agent_components.code_executor import LocalCommandLineCodeExecutor, CodeBlock
from agnext.agent_components.func_with_reqs import FunctionWithRequirements, with_requirements
import polars
import pytest
from agnext.agent_components.code_executor import (
CodeBlock,
LocalCommandLineCodeExecutor,
)
from agnext.agent_components.func_with_reqs import (
FunctionWithRequirements,
with_requirements,
)
def add_two_numbers(a: int, b: int) -> int:
"""Add two numbers together."""
@ -46,7 +51,9 @@ def function_missing_reqs() -> "polars.DataFrame":
def test_can_load_function_with_reqs() -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[load_data])
executor = LocalCommandLineCodeExecutor(
work_dir=temp_dir, functions=[load_data]
)
code = f"""from {executor.functions_module} import load_data
import polars
@ -65,7 +72,9 @@ print(data['name'][0])"""
def test_can_load_function() -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[add_two_numbers])
executor = LocalCommandLineCodeExecutor(
work_dir=temp_dir, functions=[add_two_numbers]
)
code = f"""from {executor.functions_module} import add_two_numbers
print(add_two_numbers(1, 2))"""
@ -80,7 +89,9 @@ print(add_two_numbers(1, 2))"""
def test_fails_for_function_incorrect_import() -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[function_incorrect_import])
executor = LocalCommandLineCodeExecutor(
work_dir=temp_dir, functions=[function_incorrect_import]
)
code = f"""from {executor.functions_module} import function_incorrect_import
function_incorrect_import()"""
@ -94,7 +105,9 @@ function_incorrect_import()"""
def test_fails_for_function_incorrect_dep() -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[function_incorrect_dep])
executor = LocalCommandLineCodeExecutor(
work_dir=temp_dir, functions=[function_incorrect_dep]
)
code = f"""from {executor.functions_module} import function_incorrect_dep
function_incorrect_dep()"""
@ -106,10 +119,11 @@ function_incorrect_dep()"""
)
def test_formatted_prompt() -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[add_two_numbers])
executor = LocalCommandLineCodeExecutor(
work_dir=temp_dir, functions=[add_two_numbers]
)
result = executor.format_functions_for_prompt()
assert (
@ -140,7 +154,6 @@ def add_two_numbers(a: int, b: int) -> int:
)
def test_can_load_str_function_with_reqs() -> None:
with tempfile.TemporaryDirectory() as temp_dir:
func = FunctionWithRequirements.from_str(