Mirror of https://github.com/microsoft/autogen.git, synced 2025-11-02 10:50:03 +00:00
First draft for chat layer. (#10)
This commit is contained in:
parent 1a9dddbcda
commit d77390dc07
132 examples/patterns.py Normal file
@@ -0,0 +1,132 @@
import argparse
import asyncio

import openai
from agnext.agent_components.models_clients.openai_client import OpenAI
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.messages import ChatMessage
from agnext.chat.patterns.group_chat import GroupChat
from agnext.chat.patterns.orchestrator import Orchestrator
from agnext.chat.runtimes import SingleThreadedRuntime


async def group_chat() -> None:
    runtime = SingleThreadedRuntime()

    joe_oai_assistant = openai.beta.assistants.create(
        model="gpt-3.5-turbo",
        name="Joe",
        instructions="You are a comedian named Joe. Make the audience laugh.",
    )
    joe_oai_thread = openai.beta.threads.create()
    joe = OpenAIAssistantAgent(
        name="Joe",
        description="Joe the comedian.",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=joe_oai_assistant.id,
        thread_id=joe_oai_thread.id,
    )

    cathy_oai_assistant = openai.beta.assistants.create(
        model="gpt-3.5-turbo",
        name="Cathy",
        instructions="You are a poet named Cathy. Write beautiful poems.",
    )
    cathy_oai_thread = openai.beta.threads.create()
    cathy = OpenAIAssistantAgent(
        name="Cathy",
        description="Cathy the poet.",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=cathy_oai_assistant.id,
        thread_id=cathy_oai_thread.id,
    )

    chat = GroupChat(
        "chat_room",
        "A round-robin chat room.",
        runtime,
        [joe, cathy],
        num_rounds=5,
    )

    response = runtime.send_message(ChatMessage(body="Run a show!", sender="external"), chat)

    while not response.done():
        await runtime.process_next()

    print((await response).body)


async def orchestrator() -> None:
    runtime = SingleThreadedRuntime()

    developer_oai_assistant = openai.beta.assistants.create(
        model="gpt-3.5-turbo",
        name="Developer",
        instructions="You are a Python developer.",
    )
    developer_oai_thread = openai.beta.threads.create()
    developer = OpenAIAssistantAgent(
        name="Developer",
        description="A developer that writes code.",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=developer_oai_assistant.id,
        thread_id=developer_oai_thread.id,
    )

    product_manager_oai_assistant = openai.beta.assistants.create(
        model="gpt-3.5-turbo",
        name="ProductManager",
        instructions="You are a product manager good at translating customer needs into software specifications.",
    )
    product_manager_oai_thread = openai.beta.threads.create()
    product_manager = OpenAIAssistantAgent(
        name="ProductManager",
        description="A product manager that plans and comes up with specs.",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=product_manager_oai_assistant.id,
        thread_id=product_manager_oai_thread.id,
    )

    chat = Orchestrator(
        "Team",
        "A software development team.",
        runtime,
        [developer, product_manager],
        model_client=OpenAI(model="gpt-3.5-turbo"),
    )

    response = runtime.send_message(
        ChatMessage(
            body="Write a simple FastAPI webapp for showing the current time.",
            sender="customer",
        ),
        chat,
    )

    while not response.done():
        await runtime.process_next()

    print((await response).body)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a pattern demo.")
    choices = ["group_chat", "orchestrator"]
    parser.add_argument(
        "--pattern",
        choices=choices,
        help="The pattern to demo.",
    )
    args = parser.parse_args()

    if args.pattern == "group_chat":
        asyncio.run(group_chat())
    elif args.pattern == "orchestrator":
        asyncio.run(orchestrator())
    else:
        raise ValueError(f"Invalid pattern: {args.pattern}")
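Assuming OPENAI_API_KEY is set in the environment (the openai client reads it by default), the demo can be run with either pattern:

    python examples/patterns.py --pattern group_chat
    python examples/patterns.py --pattern orchestrator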
@@ -1,123 +0,0 @@
import asyncio
import random
from dataclasses import dataclass
from typing import List

from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime
from agnext.core.agent_runtime import AgentRuntime


# TODO: a runtime should be able to handle multiple types of messages
# TODO: allow request and response to be different message types
# should support this in handlers.
@dataclass
class GroupChatMessage:
    body: str
    sender: str
    require_response: bool


class GroupChatParticipant(TypeRoutedAgent[GroupChatMessage]):
    def __init__(self, name: str, runtime: AgentRuntime[GroupChatMessage]) -> None:
        super().__init__(name, runtime)

    @message_handler(GroupChatMessage)
    async def on_new_message(self, message: GroupChatMessage) -> GroupChatMessage:
        print(f"{self.name} received message from {message.sender}: {message.body}")
        if not message.require_response:
            return GroupChatMessage(body="OK", sender=self.name, require_response=False)
        # Generate a random response.
        response_body = random.choice(
            [
                "Hello!",
                "Hi!",
                "Hey!",
                "How are you?",
                "What's up?",
                "Good day!",
                "Good morning!",
                "Good evening!",
                "Good afternoon!",
                "Good night!",
                "Good bye!",
                "Bye!",
                "See you later!",
                "See you soon!",
                "See you!",
            ]
        )
        return GroupChatMessage(body=response_body, sender=self.name, require_response=False)


class RoundRobinChat(TypeRoutedAgent[GroupChatMessage]):
    def __init__(
        self, name: str, runtime: AgentRuntime[GroupChatMessage], agents: List[GroupChatParticipant], num_rounds: int
    ) -> None:
        super().__init__(name, runtime)
        self._agents = agents
        self._num_rounds = num_rounds

    @message_handler(GroupChatMessage)
    async def on_new_message(self, message: GroupChatMessage) -> GroupChatMessage:
        print(f"{self.name} received task request from {message.sender}: {message.body}")

        history = [message]
        previous_speaker: TypeRoutedAgent[GroupChatMessage] | None = None
        round = 0

        while round < self._num_rounds:
            # 1. Select speaker.
            speaker = self._agents[round % len(self._agents)]

            # 2. Send the last message to non-speaking agents.
            for agent in self._agents:
                if agent is not previous_speaker and agent is not speaker:
                    # TODO: should support a separate message type for just passing on a message.
                    _ = await self._send_message(
                        GroupChatMessage(body=history[-1].body, sender=history[-1].sender, require_response=False),
                        agent,
                    )

            # 3. Send the last message to the speaking agent and ask to speak.
            if previous_speaker is not speaker:
                response = await self._send_message(
                    GroupChatMessage(body=history[-1].body, sender=history[-1].sender, require_response=True), speaker
                )
            else:
                # The same speaker is speaking again.
                # TODO: should support a separate message type for request to speak only.
                response = await self._send_message(
                    GroupChatMessage(body="", sender=self.name, require_response=True), speaker
                )
            print(f"Speaker {speaker.name} responded with: {response.body}")

            # 4. Append the response to the history.
            history.append(response)

            # 5. Update the previous speaker.
            previous_speaker = speaker

            # 6. Increment the round.
            round += 1

        # Construct the final response.
        response_body = "\n".join([f"{message.sender}: {message.body}" for message in history])
        return GroupChatMessage(body=response_body, sender=self.name, require_response=False)


async def main() -> None:
    runtime = SingleThreadedAgentRuntime[GroupChatMessage]()
    participants = [GroupChatParticipant(f"participant_{i}", runtime) for i in range(3)]
    chat = RoundRobinChat("chat_room", runtime, participants, num_rounds=10)

    response = runtime.send_message(GroupChatMessage(body="Hello!", sender="external", require_response=True), chat)

    while not response.done():
        await runtime.process_next()

    print((await response).body)


if __name__ == "__main__":
    asyncio.run(main())
@@ -2,7 +2,14 @@ from __future__ import annotations

from typing import Mapping, Optional, Sequence, runtime_checkable

-from typing_extensions import Any, AsyncGenerator, List, Protocol, Required, TypedDict, Union
+from typing_extensions import (
+    Any,
+    AsyncGenerator,
+    Protocol,
+    Required,
+    TypedDict,
+    Union,
+)

from .types import CreateResult, FunctionDefinition, LLMMessage, RequestUsage

@@ -18,7 +25,7 @@ class ModelClient(Protocol):
    # Caching has to be handled internally as they can depend on the create args that were stored in the constructor
    async def create(
        self,
-        messages: List[LLMMessage],
+        messages: Sequence[LLMMessage],
        functions: Sequence[FunctionDefinition] = [],
        # None means do not override the default
        # A value means to override the client default - often specified in the constructor
@@ -28,7 +35,7 @@ class ModelClient(Protocol):

    def create_stream(
        self,
-        messages: List[LLMMessage],
+        messages: Sequence[LLMMessage],
        functions: Sequence[FunctionDefinition] = [],
        # None means do not override the default
        # A value means to override the client default - often specified in the constructor

@@ -154,13 +154,17 @@ def func_call_to_oai(message: FunctionCall) -> ChatCompletionMessageToolCallParam
    )


-def tool_message_to_oai(message: FunctionExecutionResultMessage) -> Sequence[ChatCompletionToolMessageParam]:
+def tool_message_to_oai(
+    message: FunctionExecutionResultMessage,
+) -> Sequence[ChatCompletionToolMessageParam]:
    return [
        ChatCompletionToolMessageParam(content=x.content, role="tool", tool_call_id=x.call_id) for x in message.content
    ]


-def assistant_message_to_oai(message: AssistantMessage) -> ChatCompletionAssistantMessageParam:
+def assistant_message_to_oai(
+    message: AssistantMessage,
+) -> ChatCompletionAssistantMessageParam:
    if isinstance(message.content, list):
        return ChatCompletionAssistantMessageParam(
            tool_calls=[func_call_to_oai(x) for x in message.content],
@@ -240,7 +244,9 @@ class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False)
    model_capabilities: Required[ModelCapabilities]


-def convert_functions(functions: Sequence[FunctionDefinition]) -> List[ChatCompletionToolParam]:
+def convert_functions(
+    functions: Sequence[FunctionDefinition],
+) -> List[ChatCompletionToolParam]:
    result: List[ChatCompletionToolParam] = []
    for func in functions:
        result.append(
@@ -292,7 +298,7 @@ class BaseOpenAI(ModelClient):

    async def create(
        self,
-        messages: List[LLMMessage],
+        messages: Sequence[LLMMessage],
        functions: Sequence[FunctionDefinition] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
@@ -343,7 +349,7 @@ class BaseOpenAI(ModelClient):
        usage = RequestUsage(
            # TODO backup token counting
            prompt_tokens=result.usage.prompt_tokens if result.usage is not None else 0,
-            completion_tokens=result.usage.completion_tokens if result.usage is not None else 0,
+            completion_tokens=(result.usage.completion_tokens if result.usage is not None else 0),
        )

        if self._resolved_model is not None:
@@ -383,7 +389,7 @@ class BaseOpenAI(ModelClient):

    async def create_stream(
        self,
-        messages: List[LLMMessage],
+        messages: Sequence[LLMMessage],
        functions: Sequence[FunctionDefinition] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},

@@ -9,7 +9,9 @@ T = TypeVar("T")


# NOTE: this works on concrete types and not inheritance
-def message_handler(target_type: Type[T]) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]:
+def message_handler(
+    target_type: Type[T],
+) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]:
    def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
        func._target_type = target_type  # type: ignore
        return func
@@ -26,7 +28,7 @@ class TypeRoutedAgent(BaseAgent[T]):
        router.add_agent(self)

        for attr in dir(self):
-            if callable(getattr(self, attr)):
+            if callable(getattr(self, attr, None)):
                handler = getattr(self, attr)
                if hasattr(handler, "_target_type"):
                    # TODO do i need to partially apply self?

0 src/agnext/chat/__init__.py Normal file
0 src/agnext/chat/agents/__init__.py Normal file
29 src/agnext/chat/agents/base.py Normal file
@@ -0,0 +1,29 @@
from ...agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from ...core.cancellation_token import CancellationToken
from ..messages import ChatMessage
from ..runtimes import SingleThreadedRuntime


class BaseChatAgent(TypeRoutedAgent[ChatMessage]):
    """The BaseAgent class for the chat API."""

    def __init__(self, name: str, description: str, runtime: SingleThreadedRuntime) -> None:
        super().__init__(name, runtime)
        self._description = description

    @property
    def description(self) -> str:
        """The description of the agent."""
        return self._description

    async def on_chat_message(self, message: ChatMessage) -> ChatMessage:
        """The method to handle chat messages."""
        raise NotImplementedError

    # TODO: how should we expose cancellation in chat layer?
    @message_handler(ChatMessage)
    async def on_chat_message_with_cancellation(
        self, message: ChatMessage, cancellation_token: CancellationToken
    ) -> ChatMessage:
        """The method to handle chat messages with cancellation."""
        return await self.on_chat_message(message)
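A minimal subclass sketch (the EchoAgent name is hypothetical, for illustration only): concrete agents override on_chat_message, and the cancellation-aware handler above routes incoming ChatMessages to it.

class EchoAgent(BaseChatAgent):
    async def on_chat_message(self, message: ChatMessage) -> ChatMessage:
        # Hypothetical example: echo the incoming body back under this agent's name.
        return ChatMessage(body=message.body, sender=self.name)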
69 src/agnext/chat/agents/oai_assistant.py Normal file
@@ -0,0 +1,69 @@
import openai

from ..agents.base import BaseChatAgent
from ..messages import ChatMessage
from ..runtimes import SingleThreadedRuntime


class OpenAIAssistantAgent(BaseChatAgent):
    def __init__(
        self,
        name: str,
        description: str,
        runtime: SingleThreadedRuntime,
        client: openai.AsyncClient,
        assistant_id: str,
        thread_id: str,
    ) -> None:
        super().__init__(name, description, runtime)
        self._client = client
        self._assistant_id = assistant_id
        self._thread_id = thread_id
        self._current_session_window_length = 0

    async def on_chat_message(self, message: ChatMessage) -> ChatMessage:
        print("---------------")
        print(f"{self.name} received message from {message.sender}: {message.body}")
        print("---------------")
        if message.reset:
            # Reset the current session window.
            self._current_session_window_length = 0

        # Save the message to the thread.
        _ = await self._client.beta.threads.messages.create(
            thread_id=self._thread_id,
            content=message.body,
            role="user",
            metadata={"sender": message.sender},
        )
        self._current_session_window_length += 1

        # If the message is a save_message_only message, return early.
        if message.save_message_only:
            return ChatMessage(body="OK", sender=self.name)

        # Create a run and wait until it finishes.
        run = await self._client.beta.threads.runs.create_and_poll(
            thread_id=self._thread_id,
            assistant_id=self._assistant_id,
            truncation_strategy={
                "type": "last_messages",
                "last_messages": self._current_session_window_length,
            },
        )

        if run.status != "completed":
            # TODO: handle other statuses.
            raise ValueError(f"Run did not complete successfully: {run}")

        # Get the last message from the run.
        response = await self._client.beta.threads.messages.list(self._thread_id, run_id=run.id, order="desc", limit=1)
        last_message_content = response.data[0].content

        # TODO: handle array of content.
        text_content = [content for content in last_message_content if content.type == "text"]
        if not text_content:
            raise ValueError(f"Expected text content in the last message: {last_message_content}")

        # TODO: handle multiple text content.
        return ChatMessage(body=text_content[0].text.value, sender=self.name)
32 src/agnext/chat/agents/random_agent.py Normal file
@@ -0,0 +1,32 @@
import random

from ..agents.base import BaseChatAgent
from ..messages import ChatMessage


class RandomResponseAgent(BaseChatAgent):
    async def on_chat_message(self, message: ChatMessage) -> ChatMessage:
        print(f"{self.name} received message from {message.sender}: {message.body}")
        if message.save_message_only:
            return ChatMessage(body="OK", sender=self.name)
        # Generate a random response.
        response_body = random.choice(
            [
                "Hello!",
                "Hi!",
                "Hey!",
                "How are you?",
                "What's up?",
                "Good day!",
                "Good morning!",
                "Good evening!",
                "Good afternoon!",
                "Good night!",
                "Good bye!",
                "Bye!",
                "See you later!",
                "See you soon!",
                "See you!",
            ]
        )
        return ChatMessage(body=response_body, sender=self.name)
13 src/agnext/chat/messages.py Normal file
@@ -0,0 +1,13 @@
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ChatMessage:
    """The message type for the chat API."""

    body: str
    sender: str
    save_message_only: bool = False
    payload: Optional[Any] = None
    reset: bool = False
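For illustration, the save-and-reset broadcast that the Orchestrator pattern below sends when distributing task specs could be built like this (values are hypothetical):

msg = ChatMessage(
    body="Here are the task specs...",  # hypothetical content
    sender="Team",
    save_message_only=True,  # recipient stores the message without generating a reply
    reset=True,  # recipient clears its per-session state first
)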
0 src/agnext/chat/patterns/__init__.py Normal file
75 src/agnext/chat/patterns/group_chat.py Normal file
@@ -0,0 +1,75 @@
from typing import List, Sequence

from ..agents.base import BaseChatAgent
from ..messages import ChatMessage
from ..runtimes import SingleThreadedRuntime


class GroupChat(BaseChatAgent):
    def __init__(
        self,
        name: str,
        description: str,
        runtime: SingleThreadedRuntime,
        agents: Sequence[BaseChatAgent],
        num_rounds: int,
    ) -> None:
        super().__init__(name, description, runtime)
        self._agents = agents
        self._num_rounds = num_rounds
        self._history: List[ChatMessage] = []

    async def on_chat_message(self, message: ChatMessage) -> ChatMessage:
        if message.reset:
            # Reset the history.
            self._history = []
        if message.save_message_only:
            # TODO: what should we do with save_message_only messages for this pattern?
            return ChatMessage(body="OK", sender=self.name)

        self._history.append(message)
        previous_speaker: BaseChatAgent | None = None
        round = 0

        while round < self._num_rounds:
            # TODO: add support for advanced speaker selection.
            # Select speaker (round-robin for now).
            speaker = self._agents[round % len(self._agents)]

            # Send the last message to non-speaking agents.
            for agent in [agent for agent in self._agents if agent is not previous_speaker and agent is not speaker]:
                _ = await self._send_message(
                    ChatMessage(
                        body=self._history[-1].body,
                        sender=self._history[-1].sender,
                        save_message_only=True,
                    ),
                    agent,
                )

            # Send the last message to the speaking agent and ask to speak.
            if previous_speaker is not speaker:
                response = await self._send_message(
                    ChatMessage(body=self._history[-1].body, sender=self._history[-1].sender),
                    speaker,
                )
            else:
                # The same speaker is speaking again.
                # TODO: should support a separate message type for request to speak only.
                response = await self._send_message(
                    ChatMessage(body="", sender=self.name),
                    speaker,
                )

            # Append the response to the history.
            self._history.append(response)

            # Update the previous speaker.
            previous_speaker = speaker

            # Increment the round.
            round += 1

        # Construct the final response.
        response_body = "\n".join([f"{message.sender}: {message.body}" for message in self._history])
        return ChatMessage(body=response_body, sender=self.name)
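For a smoke test that avoids OpenAI calls, the RandomResponseAgent above could stand in for the assistants (a sketch under that assumption), driven exactly as in examples/patterns.py:

runtime = SingleThreadedRuntime()
agents = [RandomResponseAgent(f"agent_{i}", "Replies with a random greeting.", runtime) for i in range(2)]
chat = GroupChat("chat_room", "A round-robin chat room.", runtime, agents, num_rounds=3)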
396 src/agnext/chat/patterns/orchestrator.py Normal file
@@ -0,0 +1,396 @@
import json
from typing import Any, List, Sequence, Tuple

from ...agent_components.model_client import ModelClient
from ...agent_components.types import AssistantMessage, LLMMessage, UserMessage
from ..agents.base import BaseChatAgent
from ..messages import ChatMessage
from ..runtimes import SingleThreadedRuntime


class Orchestrator(BaseChatAgent):
    def __init__(
        self,
        name: str,
        description: str,
        runtime: SingleThreadedRuntime,
        agents: Sequence[BaseChatAgent],
        model_client: ModelClient,
        max_turns: int = 30,
        max_stalled_turns_before_retry: int = 2,
        max_retry_attempts: int = 1,
    ) -> None:
        super().__init__(name, description, runtime)
        self._agents = agents
        self._model_client = model_client
        self._max_turns = max_turns
        self._max_stalled_turns_before_retry = max_stalled_turns_before_retry
        self._max_retry_attempts_before_educated_guess = max_retry_attempts
        self._history: List[ChatMessage] = []

    async def on_chat_message(self, message: ChatMessage) -> ChatMessage:
        # A task is received.
        task = message.body

        if message.reset:
            # Reset the history.
            self._history = []
        if message.save_message_only:
            # TODO: what should we do with save_message_only messages for this pattern?
            return ChatMessage(body="OK", sender=self.name)

        # Prepare the task.
        team, names, facts, plan = await self._prepare_task(task, message.sender)

        # Main loop.
        total_turns = 0
        retry_attempts = 0
        ledgers: List[List[LLMMessage]] = []
        while total_turns < self._max_turns:
            # Create the task specs.
            task_specs = f"""
We are working to address the following user request:

{task}


To answer this request we have assembled the following team:

{team}

Some additional points to consider:

{facts}

{plan}
""".strip()

            # Send the task specs to the team and signal a reset.
            for agent in self._agents:
                _ = await self._send_message(
                    ChatMessage(
                        body=task_specs,
                        sender=self.name,
                        save_message_only=True,
                        reset=True,
                    ),
                    agent,
                )

            # Create the ledger.
            ledger: List[LLMMessage] = [
                AssistantMessage(
                    content=task_specs,
                    source=self.name,
                )
            ]
            ledgers.append(ledger)

            # Inner loop.
            stalled_turns = 0
            while total_turns < self._max_turns:
                # Reflect on the task.
                data = await self._reflect_on_task(task, team, names, ledger, message.sender)

                # Check if the request is satisfied.
                if data["is_request_satisfied"]["answer"]:
                    return ChatMessage(
                        body="The task has been successfully addressed.",
                        sender=self.name,
                        payload={
                            "ledgers": ledgers,
                            "status": "success",
                            "reason": data["is_request_satisfied"]["reason"],
                        },
                    )

                # Update stalled turns.
                if data["is_progress_being_made"]["answer"]:
                    stalled_turns = max(0, stalled_turns - 1)
                else:
                    stalled_turns += 1

                # Handle retry.
                if stalled_turns > self._max_stalled_turns_before_retry:
                    # In a retry, we need to rewrite the facts and the plan.

                    # Rewrite the facts.
                    facts = await self._rewrite_facts(facts, ledger, message.sender)

                    # Increment the retry attempts.
                    retry_attempts += 1

                    # Check if we should just guess.
                    if retry_attempts > self._max_retry_attempts_before_educated_guess:
                        # Make an educated guess.
                        educated_guess = await self._educated_guess(facts, ledger, message.sender)
                        if educated_guess["has_educated_guesses"]["answer"]:
                            return ChatMessage(
                                body="The task is addressed with an educated guess.",
                                sender=self.name,
                                payload={
                                    "ledgers": ledgers,
                                    "status": "educated_guess",
                                    "reason": educated_guess["has_educated_guesses"]["reason"],
                                },
                            )

                    # Come up with a new plan.
                    plan = await self._rewrite_plan(team, ledger, message.sender)

                    # Exit the inner loop.
                    break

                # Get the subtask.
                subtask = data["instruction_or_question"]["answer"]
                if subtask is None:
                    subtask = ""

                # Find the speaker.
                try:
                    speaker = next(agent for agent in self._agents if agent.name == data["next_speaker"]["answer"])
                except StopIteration as e:
                    raise ValueError(f"Invalid next speaker: {data['next_speaker']['answer']}") from e

                # Update all other agents.
                for agent in [agent for agent in self._agents if agent != speaker]:
                    _ = await self._send_message(
                        ChatMessage(
                            body=subtask,
                            sender=self.name,
                            save_message_only=True,
                        ),
                        agent,
                    )

                # Send the subtask to the speaker and ask it to speak.
                speaker_response = await self._send_message(
                    ChatMessage(body=subtask, sender=self.name),
                    speaker,
                )

                # Update the ledger.
                ledger.append(
                    AssistantMessage(
                        content=subtask,
                        source=self.name,
                    )
                )

                # Update all other agents with the speaker's response.
                for agent in [agent for agent in self._agents if agent != speaker]:
                    _ = await self._send_message(
                        ChatMessage(
                            body=speaker_response.body,
                            sender=speaker_response.sender,
                            save_message_only=True,
                        ),
                        agent,
                    )

                # Update the ledger.
                ledger.append(
                    UserMessage(
                        content=speaker_response.body,
                        source=speaker_response.sender,
                    )
                )

                # Increment the total turns.
                total_turns += 1

        return ChatMessage(
            body="The task was not addressed.",
            sender=self.name,
            payload={
                "ledgers": ledgers,
                "status": "failure",
                "reason": "The maximum number of turns was reached.",
            },
        )

    async def _prepare_task(self, task: str, sender: str) -> Tuple[str, str, str, str]:
        # A reusable description of the team.
        team = "\n".join([agent.name + ": " + agent.description for agent in self._agents])
        names = ", ".join([agent.name for agent in self._agents])

        # A place to store relevant facts.
        facts = ""

        # A place to store the plan.
        plan = ""

        # Start by writing what we know
        closed_book_prompt = f"""Below I will present you a request. Before we begin addressing the request, please answer the following pre-survey to the best of your ability. Keep in mind that you are Ken Jennings-level with trivia, and Mensa-level with puzzles, so there should be a deep well to draw from.

Here is the request:

{task}

Here is the pre-survey:

1. Please list any specific facts or figures that are GIVEN in the request itself. It is possible that there are none.
2. Please list any facts that may need to be looked up, and WHERE SPECIFICALLY they might be found. In some cases, authoritative sources are mentioned in the request itself.
3. Please list any facts that may need to be derived (e.g., via logical deduction, simulation, or computation)
4. Please list any facts that are recalled from memory, hunches, well-reasoned guesses, etc.

When answering this survey, keep in mind that "facts" will typically be specific names, dates, statistics, etc. Your answer should use headings:

1. GIVEN OR VERIFIED FACTS
2. FACTS TO LOOK UP
3. FACTS TO DERIVE
4. EDUCATED GUESSES
""".strip()

        starter_messages: List[LLMMessage] = [
            UserMessage(
                content=closed_book_prompt,
                source=sender,
            )
        ]
        facts_response = await self._model_client.create(messages=starter_messages)
        starter_messages.append(
            AssistantMessage(
                content=facts_response.content,
                source=self.name,
            )
        )
        facts = str(facts_response.content)

        # Make an initial plan
        plan_prompt = f"""Fantastic. To address this request we have assembled the following team:

{team}

Based on the team composition, and known and unknown facts, please devise a short bullet-point plan for addressing the original request. Remember, there is no requirement to involve all team members -- a team member's particular expertise may not be needed for this task.""".strip()
        starter_messages.append(
            UserMessage(
                content=plan_prompt,
                source=sender,
            )
        )
        plan_response = await self._model_client.create(messages=starter_messages)
        starter_messages.append(
            AssistantMessage(
                content=plan_response.content,
                source=self.name,
            )
        )
        plan = str(plan_response.content)

        return team, names, facts, plan

    async def _reflect_on_task(
        self,
        task: str,
        team: str,
        names: str,
        ledger: List[LLMMessage],
        sender: str,
    ) -> Any:
        step_prompt = f"""
Recall we are working on the following request:

{task}

And we have assembled the following team:

{team}

To make progress on the request, please answer the following questions, including necessary reasoning:

- Is the request fully satisfied? (True if complete, or False if the original request has yet to be SUCCESSFULLY addressed)
- Are we making forward progress? (True if just starting, or recent messages are adding value. False if recent messages show evidence of being stuck in a reasoning or action loop, or there is evidence of significant barriers to success such as the inability to read from a required file)
- Who should speak next? (select from: {names})
- What instruction or question would you give this team member? (Phrase as if speaking directly to them, and include any specific information they may need)

Please output an answer in pure JSON format according to the following schema. The JSON object must be parsable as-is. DO NOT OUTPUT ANYTHING OTHER THAN JSON, AND DO NOT DEVIATE FROM THIS SCHEMA:

{{
    "is_request_satisfied": {{
        "reason": string,
        "answer": boolean
    }},
    "is_progress_being_made": {{
        "reason": string,
        "answer": boolean
    }},
    "next_speaker": {{
        "reason": string,
        "answer": string (select from: {names})
    }},
    "instruction_or_question": {{
        "reason": string,
        "answer": string
    }}
}}
""".strip()
        step_response = await self._model_client.create(
            messages=ledger + [UserMessage(content=step_prompt, source=sender)],
            extra_create_args={"response_format": {"type": "json_object"}},
        )
        step_response_json = str(step_response.content)
        # TODO: handle invalid JSON.
        # TODO: use typed dictionary.
        return json.loads(step_response_json)

    async def _rewrite_facts(self, facts: str, ledger: List[LLMMessage], sender: str) -> str:
        new_facts_prompt = f"""It's clear we aren't making as much progress as we would like, but we may have learned something new. Please rewrite the following fact sheet, updating it to include anything new we have learned. This is also a good time to update educated guesses (please add or update at least one educated guess or hunch, and explain your reasoning).

{facts}
""".strip()
        ledger.append(
            UserMessage(
                content=new_facts_prompt,
                source=sender,
            )
        )
        new_facts_response = await self._model_client.create(messages=ledger)
        facts = str(new_facts_response.content)
        ledger.append(
            AssistantMessage(
                content=facts,
                source=self.name,
            )
        )
        return facts

    async def _educated_guess(self, facts: str, ledger: List[LLMMessage], sender: str) -> Any:
        # Make an educated guess.
        educated_guess_prompt = f"""Given the following information

{facts}

Please answer the following question, including necessary reasoning:
- Do you have two or more congruent pieces of information that will allow you to make an educated guess for the original request? The educated guess MUST answer the question.
Please output an answer in pure JSON format according to the following schema. The JSON object must be parsable as-is. DO NOT OUTPUT ANYTHING OTHER THAN JSON, AND DO NOT DEVIATE FROM THIS SCHEMA:

{{
    "has_educated_guesses": {{
        "reason": string,
        "answer": boolean
    }}
}}
""".strip()
        educated_guess_response = await self._model_client.create(
            messages=ledger + [UserMessage(content=educated_guess_prompt, source=sender)],
            extra_create_args={"response_format": {"type": "json_object"}},
        )
        # TODO: handle invalid JSON.
        # TODO: use typed dictionary.
        return json.loads(str(educated_guess_response.content))

    async def _rewrite_plan(self, team: str, ledger: List[LLMMessage], sender: str) -> str:
        new_plan_prompt = f"""Please come up with a new plan expressed in bullet points. Keep in mind the following team composition, and do not involve any other outside people in the plan -- we cannot contact anyone else.

Team membership:
{team}
""".strip()
        ledger.append(
            UserMessage(
                content=new_plan_prompt,
                source=sender,
            )
        )
        new_plan_response = await self._model_client.create(messages=ledger)
        return str(new_plan_response.content)
12 src/agnext/chat/runtimes.py Normal file
@@ -0,0 +1,12 @@
from ..application_components.single_threaded_agent_runtime import (
    SingleThreadedAgentRuntime,
)
from .messages import ChatMessage


# The built-in runtime for the chat API.
class SingleThreadedRuntime(SingleThreadedAgentRuntime[ChatMessage]):
    pass


# Each new built-in runtime should be able to handle ChatMessage type.
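For reference, driving this runtime follows the send-then-drain loop used in examples/patterns.py above (some_agent stands for any agent registered on the runtime):

runtime = SingleThreadedRuntime()
response = runtime.send_message(ChatMessage(body="Hello!", sender="external"), some_agent)
while not response.done():
    await runtime.process_next()
print((await response).body)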