chat completion agent and state (#32)

* chat completion agent

* use response format enum; use Message type for chat history; remove name from state
This commit is contained in:
Eric Zhu 2024-05-28 23:18:28 -07:00 committed by GitHub
parent e3a2f79e65
commit cd147b6eed
6 changed files with 176 additions and 37 deletions

View File

@ -3,14 +3,17 @@ import asyncio
import logging
import openai
from agnext.agent_components.types import SystemMessage
from agnext.application_components import (
SingleThreadedAgentRuntime,
)
from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
from agnext.chat.patterns.orchestrator_chat import OrchestratorChat
from agnext.chat.types import TextMessage
from agnext.core._agent import Agent
from agnext.agent_components.model_client import OpenAI
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
from typing_extensions import Any, override
@ -95,7 +98,14 @@ async def group_chat(message: str) -> None:
thread_id=cathy_oai_thread.id,
)
chat = GroupChat("Host", "A round-robin chat room.", runtime, [joe, cathy], num_rounds=5, output=ConcatOutput())
chat = GroupChat(
"Host",
"A round-robin chat room.",
runtime,
[joe, cathy],
num_rounds=5,
output=ConcatOutput(),
)
response = runtime.send_message(TextMessage(content=message, source="host"), chat)
@ -105,7 +115,7 @@ async def group_chat(message: str) -> None:
await response
async def orchestrator(message: str) -> None:
async def orchestrator_oai_assistant(message: str) -> None:
runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())
developer_oai_assistant = openai.beta.assistants.create(
@ -169,7 +179,63 @@ async def orchestrator(message: str) -> None:
)
chat = OrchestratorChat(
"Orchestrator Chat",
"OrchestratorChat",
"A software development team.",
runtime,
orchestrator=orchestrator,
planner=planner,
specialists=[developer, product_manager],
)
response = runtime.send_message(TextMessage(content=message, source="Customer"), chat)
while not response.done():
await runtime.process_next()
print((await response).content) # type: ignore
async def orchestrator_chat_completion(message: str) -> None:
runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())
developer = ChatCompletionAgent(
name="Developer",
description="A developer that writes code.",
runtime=runtime,
system_messages=[SystemMessage("You are a Python developer.")],
model_client=OpenAI(model="gpt-3.5-turbo"),
)
product_manager = ChatCompletionAgent(
name="ProductManager",
description="A product manager that plans and comes up with specs.",
runtime=runtime,
system_messages=[
SystemMessage("You are a product manager good at translating customer needs into software specifications.")
],
model_client=OpenAI(model="gpt-3.5-turbo"),
)
planner = ChatCompletionAgent(
name="Planner",
description="A planner that organizes and schedules tasks.",
runtime=runtime,
system_messages=[SystemMessage("You are a planner of complex tasks.")],
model_client=OpenAI(model="gpt-4-turbo"),
)
orchestrator = ChatCompletionAgent(
name="Orchestrator",
description="An orchestrator that coordinates the team.",
runtime=runtime,
system_messages=[
SystemMessage("You are an orchestrator that coordinates the team to complete a complex task.")
],
model_client=OpenAI(model="gpt-4-turbo"),
)
chat = OrchestratorChat(
"OrchestratorChat",
"A software development team.",
runtime,
orchestrator=orchestrator,
@ -187,18 +253,16 @@ async def orchestrator(message: str) -> None:
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run a pattern demo.")
chocies = ["group_chat", "orchestrator"]
choices = {
"group_chat": group_chat,
"orchestrator_oai_assistant": orchestrator_oai_assistant,
"orchestrator_chat_completion": orchestrator_chat_completion,
}
parser.add_argument(
"--pattern",
choices=chocies,
choices=list(choices.keys()),
help="The pattern to demo.",
)
parser.add_argument("--message", help="The message to send.")
args = parser.parse_args()
if args.pattern == "group_chat":
asyncio.run(group_chat(args.message))
elif args.pattern == "orchestrator":
asyncio.run(orchestrator(args.message))
else:
raise ValueError(f"Invalid pattern: {args.pattern}")
asyncio.run(choices[args.pattern](args.message))

View File

@ -0,0 +1,63 @@
from typing import Any, Callable, Dict, List, Mapping
from agnext.agent_components.model_client import ModelClient
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.agent_components.types import SystemMessage
from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.types import Message, Reset, RespondNow, ResponseFormat, TextMessage
from agnext.chat.utils import convert_messages_to_llm_messages
from agnext.core import AgentRuntime, CancellationToken
class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
    """A chat agent backed by a chat-completion model client.

    Incoming ``TextMessage``s are accumulated as conversation history; a
    ``RespondNow`` message triggers one model call over the system messages
    plus that history and returns the model's reply as a ``TextMessage``.
    A ``Reset`` message clears the history.

    NOTE(review): the agent does not record its own replies in the history;
    presumably the surrounding pattern feeds them back as ``TextMessage``s —
    confirm against callers.
    """

    def __init__(
        self,
        name: str,
        description: str,
        runtime: AgentRuntime,
        system_messages: List[SystemMessage],
        model_client: ModelClient,
        tools: Dict[str, Callable[..., str]] | None = None,
    ) -> None:
        super().__init__(name, description, runtime)
        # System prompts are kept apart from the running conversation so they
        # can be prepended on every model call.
        self._system_messages = system_messages
        self._client = model_client
        self._tools = {} if tools is None else tools
        self._chat_messages: List[Message] = []

    @message_handler(TextMessage)
    async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
        """Append an incoming message to the conversation history."""
        self._chat_messages.append(message)

    @message_handler(Reset)
    async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
        """Discard the accumulated conversation history."""
        self._chat_messages = []

    @message_handler(RespondNow)
    async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
        """Generate a reply from the model over the current history.

        Raises:
            ValueError: if the model returns non-string content.
        """
        # Translate the enum into the wire-level response-format dict.
        wants_json = message.response_format == ResponseFormat.json_object
        response_format = {"type": "json_object" if wants_json else "text"}

        llm_messages = self._system_messages + convert_messages_to_llm_messages(self._chat_messages, self.name)
        response = await self._client.create(
            llm_messages,
            extra_create_args={"response_format": response_format},
        )
        if not isinstance(response.content, str):
            raise ValueError(f"Unexpected response: {response.content}")
        return TextMessage(content=response.content, source=self.name)

    def save_state(self) -> Mapping[str, Any]:
        """Snapshot the agent's mutable state for later restoration."""
        return {
            "description": self.description,
            "chat_messages": self._chat_messages,
            "system_messages": self._system_messages,
        }

    def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore state previously produced by ``save_state``."""
        self._chat_messages = state["chat_messages"]
        self._system_messages = state["system_messages"]
        self._description = state["description"]

View File

@ -1,10 +1,11 @@
from typing import Callable, Dict, List
from typing import Any, Callable, Dict, List, Mapping
import openai
from openai.types.beta import AssistantResponseFormatParam
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.types import Reset, RespondNow, TextMessage
from agnext.chat.types import Reset, RespondNow, ResponseFormat, TextMessage
from agnext.core import AgentRuntime, CancellationToken
@ -56,12 +57,16 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
@message_handler(RespondNow)
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
# Handle response format.
if message.response_format == ResponseFormat.json_object:
response_format = AssistantResponseFormatParam(type="json_object")
else:
response_format = AssistantResponseFormatParam(type="text")
# Create a run and wait until it finishes.
run = await self._client.beta.threads.runs.create_and_poll(
thread_id=self._thread_id,
assistant_id=self._assistant_id,
response_format=message.response_format,
response_format=response_format,
)
if run.status != "completed":
@ -79,3 +84,15 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
# TODO: handle multiple text content.
return TextMessage(content=text_content[0].text.value, source=self.name)
def save_state(self) -> Mapping[str, Any]:
    """Snapshot the agent's restorable state.

    Returns the description plus the OpenAI assistant and thread ids, the
    keys ``load_state`` reads back.
    """
    return {
        "description": self.description,
        "assistant_id": self._assistant_id,
        "thread_id": self._thread_id,
    }
def load_state(self, state: Mapping[str, Any]) -> None:
    """Restore state previously produced by ``save_state``.

    Raises ``KeyError`` if any expected key is missing from *state*.
    """
    self._description = state["description"]
    self._assistant_id = state["assistant_id"]
    self._thread_id = state["thread_id"]

View File

@ -1,13 +0,0 @@
from dataclasses import dataclass
from typing import Any, Optional
@dataclass
class ChatMessage:
    """The message type for the chat API."""

    # Text content of the message.
    body: str
    # Identifier of the message originator.
    sender: str
    # NOTE(review): flags below are not exercised in the visible code;
    # presumed semantics — confirm against callers.
    # Presumably: store the message in history without triggering a response.
    save_message_only: bool = False
    # Optional arbitrary data attached to the message.
    payload: Optional[Any] = None
    # Presumably: signal the receiver to reset its conversation state.
    reset: bool = False

View File

@ -4,7 +4,7 @@ from typing import Any, Sequence, Tuple
from ...agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from ...core import AgentRuntime, CancellationToken
from ..agents.base import BaseChatAgent
from ..types import Reset, RespondNow, TextMessage
from ..types import Reset, RespondNow, ResponseFormat, TextMessage
class OrchestratorChat(BaseChatAgent, TypeRoutedAgent):
@ -140,7 +140,11 @@ Some additional points to consider:
# Update all other agents with the speaker's response.
for agent in [agent for agent in self._specialists if agent != speaker] + [self._orchestrator]:
self._send_message(
TextMessage(content=speaker_response.content, source=speaker_response.source), agent
TextMessage(
content=speaker_response.content,
source=speaker_response.source,
),
agent,
)
# Increment the total turns.
@ -255,7 +259,7 @@ Please output an answer in pure JSON format according to the following schema. T
self._send_message(TextMessage(content=step_prompt, source=sender), self._orchestrator)
# Request a response.
step_response = await self._send_message(
RespondNow(response_format={"type": "json_object"}), self._orchestrator
RespondNow(response_format=ResponseFormat.json_object), self._orchestrator
)
# TODO: handle invalid JSON.
# TODO: use typed dictionary.
@ -293,7 +297,7 @@ Please output an answer in pure JSON format according to the following schema. T
self._send_message(TextMessage(content=educated_guess_promt, source=sender), self._orchestrator)
# Request a response.
educated_guess_response = await self._send_message(
RespondNow(response_format={"type": "json_object"}), self._orchestrator
RespondNow(response_format=ResponseFormat.json_object), self._orchestrator
)
# TODO: handle invalid JSON.
# TODO: use typed dictionary.

View File

@ -1,9 +1,8 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Literal, Union
from openai.types.beta import AssistantResponseFormatParam
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Union
from agnext.agent_components.image import Image
from agnext.agent_components.types import FunctionCall
@ -44,9 +43,14 @@ class FunctionExecutionResultMessage(BaseMessage):
Message = Union[TextMessage, MultiModalMessage, FunctionCallMessage, FunctionExecutionResultMessage]
class ResponseFormat(Enum):
    """Output formats an agent can be asked to produce in a ``RespondNow`` request."""

    text = "text"
    json_object = "json_object"
@dataclass
class RespondNow:
response_format: Union[Literal["none", "auto"], AssistantResponseFormatParam] = "auto"
response_format: ResponseFormat = field(default=ResponseFormat.text)
class Reset: ...