mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-30 00:30:23 +00:00
chat completion agent and state (#32)
* chat completion agent * use response format enum; use Message type for chat history; remove name from state
This commit is contained in:
parent
e3a2f79e65
commit
cd147b6eed
@ -3,14 +3,17 @@ import asyncio
|
||||
import logging
|
||||
|
||||
import openai
|
||||
from agnext.agent_components.types import SystemMessage
|
||||
from agnext.application_components import (
|
||||
SingleThreadedAgentRuntime,
|
||||
)
|
||||
from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent
|
||||
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
|
||||
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
|
||||
from agnext.chat.patterns.orchestrator_chat import OrchestratorChat
|
||||
from agnext.chat.types import TextMessage
|
||||
from agnext.core._agent import Agent
|
||||
from agnext.agent_components.model_client import OpenAI
|
||||
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
|
||||
from typing_extensions import Any, override
|
||||
|
||||
@ -95,7 +98,14 @@ async def group_chat(message: str) -> None:
|
||||
thread_id=cathy_oai_thread.id,
|
||||
)
|
||||
|
||||
chat = GroupChat("Host", "A round-robin chat room.", runtime, [joe, cathy], num_rounds=5, output=ConcatOutput())
|
||||
chat = GroupChat(
|
||||
"Host",
|
||||
"A round-robin chat room.",
|
||||
runtime,
|
||||
[joe, cathy],
|
||||
num_rounds=5,
|
||||
output=ConcatOutput(),
|
||||
)
|
||||
|
||||
response = runtime.send_message(TextMessage(content=message, source="host"), chat)
|
||||
|
||||
@ -105,7 +115,7 @@ async def group_chat(message: str) -> None:
|
||||
await response
|
||||
|
||||
|
||||
async def orchestrator(message: str) -> None:
|
||||
async def orchestrator_oai_assistant(message: str) -> None:
|
||||
runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())
|
||||
|
||||
developer_oai_assistant = openai.beta.assistants.create(
|
||||
@ -169,7 +179,63 @@ async def orchestrator(message: str) -> None:
|
||||
)
|
||||
|
||||
chat = OrchestratorChat(
|
||||
"Orchestrator Chat",
|
||||
"OrchestratorChat",
|
||||
"A software development team.",
|
||||
runtime,
|
||||
orchestrator=orchestrator,
|
||||
planner=planner,
|
||||
specialists=[developer, product_manager],
|
||||
)
|
||||
|
||||
response = runtime.send_message(TextMessage(content=message, source="Customer"), chat)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
|
||||
print((await response).content) # type: ignore
|
||||
|
||||
|
||||
async def orchestrator_chat_completion(message: str) -> None:
|
||||
runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())
|
||||
|
||||
developer = ChatCompletionAgent(
|
||||
name="Developer",
|
||||
description="A developer that writes code.",
|
||||
runtime=runtime,
|
||||
system_messages=[SystemMessage("You are a Python developer.")],
|
||||
model_client=OpenAI(model="gpt-3.5-turbo"),
|
||||
)
|
||||
|
||||
product_manager = ChatCompletionAgent(
|
||||
name="ProductManager",
|
||||
description="A product manager that plans and comes up with specs.",
|
||||
runtime=runtime,
|
||||
system_messages=[
|
||||
SystemMessage("You are a product manager good at translating customer needs into software specifications.")
|
||||
],
|
||||
model_client=OpenAI(model="gpt-3.5-turbo"),
|
||||
)
|
||||
|
||||
planner = ChatCompletionAgent(
|
||||
name="Planner",
|
||||
description="A planner that organizes and schedules tasks.",
|
||||
runtime=runtime,
|
||||
system_messages=[SystemMessage("You are a planner of complex tasks.")],
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
)
|
||||
|
||||
orchestrator = ChatCompletionAgent(
|
||||
name="Orchestrator",
|
||||
description="An orchestrator that coordinates the team.",
|
||||
runtime=runtime,
|
||||
system_messages=[
|
||||
SystemMessage("You are an orchestrator that coordinates the team to complete a complex task.")
|
||||
],
|
||||
model_client=OpenAI(model="gpt-4-turbo"),
|
||||
)
|
||||
|
||||
chat = OrchestratorChat(
|
||||
"OrchestratorChat",
|
||||
"A software development team.",
|
||||
runtime,
|
||||
orchestrator=orchestrator,
|
||||
@ -187,18 +253,16 @@ async def orchestrator(message: str) -> None:
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: pick one demo pattern and run it with the given message.
    parser = argparse.ArgumentParser(description="Run a pattern demo.")
    # Map each selectable pattern name to the coroutine function that runs it.
    choices = {
        "group_chat": group_chat,
        "orchestrator_oai_assistant": orchestrator_oai_assistant,
        "orchestrator_chat_completion": orchestrator_chat_completion,
    }
    parser.add_argument(
        "--pattern",
        choices=list(choices),
        # Without this, omitting --pattern leaves args.pattern as None and the
        # dictionary lookup below fails with an unhelpful KeyError.
        required=True,
        help="The pattern to demo.",
    )
    parser.add_argument("--message", help="The message to send.")
    args = parser.parse_args()

    # argparse has already validated the pattern against the dictionary keys.
    asyncio.run(choices[args.pattern](args.message))
|
||||
|
||||
63
src/agnext/chat/agents/chat_completion_agent.py
Normal file
63
src/agnext/chat/agents/chat_completion_agent.py
Normal file
@ -0,0 +1,63 @@
|
||||
from typing import Any, Callable, Dict, List, Mapping
|
||||
|
||||
from agnext.agent_components.model_client import ModelClient
|
||||
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
|
||||
from agnext.agent_components.types import SystemMessage
|
||||
from agnext.chat.agents.base import BaseChatAgent
|
||||
from agnext.chat.types import Message, Reset, RespondNow, ResponseFormat, TextMessage
|
||||
from agnext.chat.utils import convert_messages_to_llm_messages
|
||||
from agnext.core import AgentRuntime, CancellationToken
|
||||
|
||||
|
||||
class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
    """Chat agent that replies from a buffered history using a model client.

    ``TextMessage`` instances are appended to an internal history, ``Reset``
    clears that history, and ``RespondNow`` triggers a model call over the
    system messages followed by the buffered history.
    """

    def __init__(
        self,
        name: str,
        description: str,
        runtime: AgentRuntime,
        system_messages: List[SystemMessage],
        model_client: ModelClient,
        tools: Dict[str, Callable[..., str]] | None = None,
    ) -> None:
        """Create the agent.

        Args:
            name: Forwarded to the base class constructor.
            description: Forwarded to the base class constructor.
            runtime: Forwarded to the base class constructor.
            system_messages: Messages placed before the chat history on every model call.
            model_client: Client whose ``create`` coroutine produces the reply.
            tools: Optional mapping of tool name to callable; an empty dict is used when omitted.
        """
        super().__init__(name, description, runtime)
        self._system_messages = system_messages
        self._client = model_client
        # NOTE: tools are stored here but no handler in this class reads them.
        self._tools = tools or {}
        self._chat_messages: List[Message] = []

    @message_handler(TextMessage)
    async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
        """Buffer an incoming text message; no reply is produced."""
        # Add a user message.
        self._chat_messages.append(message)

    @message_handler(Reset)
    async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
        """Discard the buffered chat history."""
        # Rebind rather than clear in place, so previously handed-out lists are untouched.
        self._chat_messages = []

    @message_handler(RespondNow)
    async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
        """Ask the model client for a reply to the buffered conversation.

        Returns:
            A ``TextMessage`` holding the model output, with this agent's name as source.

        Raises:
            ValueError: If the model response content is not a string.
        """
        # Anything other than an explicit JSON request falls back to plain text.
        wants_json = message.response_format == ResponseFormat.json_object
        response_format = {"type": "json_object" if wants_json else "text"}
        response = await self._client.create(
            self._system_messages + convert_messages_to_llm_messages(self._chat_messages, self.name),
            extra_create_args={"response_format": response_format},
        )
        if not isinstance(response.content, str):
            raise ValueError(f"Unexpected response: {response.content}")
        # NOTE(review): the reply is only returned; it is not appended to the
        # buffered history here -- confirm the calling pattern feeds it back if needed.
        return TextMessage(content=response.content, source=self.name)

    def save_state(self) -> Mapping[str, Any]:
        """Return a snapshot of the description, chat history and system messages.

        The two lists are shallow copies, so messages handled after the call do
        not change a snapshot that was taken earlier.
        """
        return {
            "description": self.description,
            "chat_messages": list(self._chat_messages),
            "system_messages": list(self._system_messages),
        }

    def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore the agent from a mapping shaped like the one ``save_state`` returns.

        Raises:
            KeyError: If a required key is missing; the agent is left unchanged in that case.
        """
        # Read every key before assigning anything so a malformed state cannot
        # leave the agent half-restored; copy the lists so later appends do not
        # mutate the caller's state object.
        chat_messages = list(state["chat_messages"])
        system_messages = list(state["system_messages"])
        description = state["description"]
        self._chat_messages = chat_messages
        self._system_messages = system_messages
        # presumably the attribute backing ``description`` in BaseChatAgent -- confirm.
        self._description = description
|
||||
@ -1,10 +1,11 @@
|
||||
from typing import Callable, Dict, List
|
||||
from typing import Any, Callable, Dict, List, Mapping
|
||||
|
||||
import openai
|
||||
from openai.types.beta import AssistantResponseFormatParam
|
||||
|
||||
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
|
||||
from agnext.chat.agents.base import BaseChatAgent
|
||||
from agnext.chat.types import Reset, RespondNow, TextMessage
|
||||
from agnext.chat.types import Reset, RespondNow, ResponseFormat, TextMessage
|
||||
from agnext.core import AgentRuntime, CancellationToken
|
||||
|
||||
|
||||
@ -56,12 +57,16 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
|
||||
@message_handler(RespondNow)
|
||||
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
|
||||
# Handle response format.
|
||||
if message.response_format == ResponseFormat.json_object:
|
||||
response_format = AssistantResponseFormatParam(type="json_object")
|
||||
else:
|
||||
response_format = AssistantResponseFormatParam(type="text")
|
||||
|
||||
# Create a run and wait until it finishes.
|
||||
run = await self._client.beta.threads.runs.create_and_poll(
|
||||
thread_id=self._thread_id,
|
||||
assistant_id=self._assistant_id,
|
||||
response_format=message.response_format,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
if run.status != "completed":
|
||||
@ -79,3 +84,15 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
|
||||
|
||||
# TODO: handle multiple text content.
|
||||
return TextMessage(content=text_content[0].text.value, source=self.name)
|
||||
|
||||
def save_state(self) -> Mapping[str, Any]:
    """Return the values needed to restore this agent later.

    Returns:
        A mapping with the agent description plus the assistant id and
        thread id this agent uses when creating runs.
    """
    snapshot: Dict[str, Any] = {}
    snapshot["description"] = self.description
    snapshot["assistant_id"] = self._assistant_id
    snapshot["thread_id"] = self._thread_id
    return snapshot
|
||||
|
||||
def load_state(self, state: Mapping[str, Any]) -> None:
    """Restore the description, assistant id and thread id from ``state``.

    Args:
        state: Mapping with the keys ``description``, ``assistant_id`` and ``thread_id``.

    Raises:
        KeyError: If one of the keys is missing; attributes restored before
            the missing key keep their new values.
    """
    # Order matters only for the partial-failure case described above.
    restore_plan = (
        ("_description", "description"),
        ("_assistant_id", "assistant_id"),
        ("_thread_id", "thread_id"),
    )
    for attribute, key in restore_plan:
        setattr(self, attribute, state[key])
|
||||
|
||||
@ -1,13 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
@dataclass
class ChatMessage:
    """The message type for the chat API.

    ``body`` and ``sender`` are required. The remaining fields are optional
    and default to ``False``, ``None`` and ``False`` respectively.
    """

    # Required fields.
    body: str
    sender: str

    # Optional fields with their defaults.
    save_message_only: bool = False
    payload: Optional[Any] = None
    reset: bool = False
|
||||
@ -4,7 +4,7 @@ from typing import Any, Sequence, Tuple
|
||||
from ...agent_components.type_routed_agent import TypeRoutedAgent, message_handler
|
||||
from ...core import AgentRuntime, CancellationToken
|
||||
from ..agents.base import BaseChatAgent
|
||||
from ..types import Reset, RespondNow, TextMessage
|
||||
from ..types import Reset, RespondNow, ResponseFormat, TextMessage
|
||||
|
||||
|
||||
class OrchestratorChat(BaseChatAgent, TypeRoutedAgent):
|
||||
@ -140,7 +140,11 @@ Some additional points to consider:
|
||||
# Update all other agents with the speaker's response.
|
||||
for agent in [agent for agent in self._specialists if agent != speaker] + [self._orchestrator]:
|
||||
self._send_message(
|
||||
TextMessage(content=speaker_response.content, source=speaker_response.source), agent
|
||||
TextMessage(
|
||||
content=speaker_response.content,
|
||||
source=speaker_response.source,
|
||||
),
|
||||
agent,
|
||||
)
|
||||
|
||||
# Increment the total turns.
|
||||
@ -255,7 +259,7 @@ Please output an answer in pure JSON format according to the following schema. T
|
||||
self._send_message(TextMessage(content=step_prompt, source=sender), self._orchestrator)
|
||||
# Request a response.
|
||||
step_response = await self._send_message(
|
||||
RespondNow(response_format={"type": "json_object"}), self._orchestrator
|
||||
RespondNow(response_format=ResponseFormat.json_object), self._orchestrator
|
||||
)
|
||||
# TODO: handle invalid JSON.
|
||||
# TODO: use typed dictionary.
|
||||
@ -293,7 +297,7 @@ Please output an answer in pure JSON format according to the following schema. T
|
||||
self._send_message(TextMessage(content=educated_guess_promt, source=sender), self._orchestrator)
|
||||
# Request a response.
|
||||
educated_guess_response = await self._send_message(
|
||||
RespondNow(response_format={"type": "json_object"}), self._orchestrator
|
||||
RespondNow(response_format=ResponseFormat.json_object), self._orchestrator
|
||||
)
|
||||
# TODO: handle invalid JSON.
|
||||
# TODO: use typed dictionary.
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Literal, Union
|
||||
|
||||
from openai.types.beta import AssistantResponseFormatParam
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Union
|
||||
|
||||
from agnext.agent_components.image import Image
|
||||
from agnext.agent_components.types import FunctionCall
|
||||
@ -44,9 +43,14 @@ class FunctionExecutionResultMessage(BaseMessage):
|
||||
Message = Union[TextMessage, MultiModalMessage, FunctionCallMessage, FunctionExecutionResultMessage]
|
||||
|
||||
|
||||
class ResponseFormat(Enum):
    """Reply formats that a ``RespondNow`` message can request from an agent."""

    # Plain-text reply.
    text = "text"

    # Reply requested as a JSON object.
    json_object = "json_object"
|
||||
|
||||
|
||||
@dataclass
class RespondNow:
    """Message asking an agent to produce its response now.

    Attributes:
        response_format: Format requested for the reply; defaults to
            ``ResponseFormat.text``.
    """

    # Single enum-typed declaration; agents compare this value against
    # ResponseFormat members, so string or dict values are not accepted.
    response_format: ResponseFormat = field(default=ResponseFormat.text)
|
||||
|
||||
|
||||
class Reset:
    """Data-less signal message; ChatCompletionAgent handles it by clearing its buffered chat history."""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user