mirror of
https://github.com/microsoft/autogen.git
synced 2025-09-25 16:16:37 +00:00
oai assistant and pattern fixes (#30)
This commit is contained in:
parent
f4a5835772
commit
ecbc3b7806
@ -1,6 +1,6 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
from typing import Any
|
||||
import logging
|
||||
|
||||
import openai
|
||||
from agnext.agent_components.model_client import OpenAI
|
||||
@ -8,13 +8,18 @@ from agnext.application_components import (
|
||||
SingleThreadedAgentRuntime,
|
||||
)
|
||||
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
|
||||
from agnext.chat.messages import ChatMessage
|
||||
from agnext.chat.patterns.group_chat import GroupChat, Output
|
||||
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
|
||||
from agnext.chat.patterns.orchestrator import Orchestrator
|
||||
from agnext.chat.types import TextMessage
|
||||
from agnext.core._agent import Agent
|
||||
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
|
||||
from typing_extensions import Any, override
|
||||
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("agnext").setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
class ConcatOutput(Output):
|
||||
class ConcatOutput(GroupChatOutput):
|
||||
def __init__(self) -> None:
|
||||
self._output = ""
|
||||
|
||||
@ -32,8 +37,26 @@ class ConcatOutput(Output):
|
||||
self._output = ""
|
||||
|
||||
|
||||
class LoggingHandler(DefaultInterventionHandler):
|
||||
@override
|
||||
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]:
|
||||
if sender is None:
|
||||
print(f"Sending message to {recipient.name}: {message}")
|
||||
else:
|
||||
print(f"Sending message from {sender.name} to {recipient.name}: {message}")
|
||||
return message
|
||||
|
||||
@override
|
||||
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]:
|
||||
if recipient is None:
|
||||
print(f"Received response from {sender.name}: {message}")
|
||||
else:
|
||||
print(f"Received response from {sender.name} to {recipient.name}: {message}")
|
||||
return message
|
||||
|
||||
|
||||
async def group_chat(message: str) -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())
|
||||
|
||||
joe_oai_assistant = openai.beta.assistants.create(
|
||||
model="gpt-3.5-turbo",
|
||||
@ -67,16 +90,16 @@ async def group_chat(message: str) -> None:
|
||||
|
||||
chat = GroupChat("Host", "A round-robin chat room.", runtime, [joe, cathy], num_rounds=5, output=ConcatOutput())
|
||||
|
||||
response = runtime.send_message(ChatMessage(body=message, sender="host"), chat)
|
||||
response = runtime.send_message(TextMessage(content=message, source="host"), chat)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
|
||||
print((await response).body) # type: ignore
|
||||
await response
|
||||
|
||||
|
||||
async def orchestrator(message: str) -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())
|
||||
|
||||
developer_oai_assistant = openai.beta.assistants.create(
|
||||
model="gpt-3.5-turbo",
|
||||
@ -117,17 +140,14 @@ async def orchestrator(message: str) -> None:
|
||||
)
|
||||
|
||||
response = runtime.send_message(
|
||||
ChatMessage(
|
||||
body=message,
|
||||
sender="customer",
|
||||
),
|
||||
TextMessage(content=message, source="customer"),
|
||||
chat,
|
||||
)
|
||||
|
||||
while not response.done():
|
||||
await runtime.process_next()
|
||||
|
||||
print((await response).body) # type: ignore
|
||||
print((await response).content) # type: ignore
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Callable, Dict
|
||||
|
||||
import openai
|
||||
|
||||
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
|
||||
@ -15,22 +17,18 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
|
||||
client: openai.AsyncClient,
|
||||
assistant_id: str,
|
||||
thread_id: str,
|
||||
tools: Dict[str, Callable[..., str]] | None = None,
|
||||
) -> None:
|
||||
super().__init__(name, description, runtime)
|
||||
self._client = client
|
||||
self._assistant_id = assistant_id
|
||||
self._thread_id = thread_id
|
||||
self._current_session_window_length = 0
|
||||
# TODO: investigate why this is 1, as setting this to 0 causes the earlest message in the window to be ignored.
|
||||
self._current_session_window_length = 1
|
||||
self._tools = tools or {}
|
||||
|
||||
# TODO: use require_response
|
||||
@message_handler(TextMessage)
|
||||
async def on_chat_message_with_cancellation(
|
||||
self, message: TextMessage, cancellation_token: CancellationToken
|
||||
) -> None:
|
||||
print("---------------")
|
||||
print(f"{self.name} received message from {message.source}: {message.content}")
|
||||
print("---------------")
|
||||
|
||||
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
|
||||
# Save the message to the thread.
|
||||
_ = await self._client.beta.threads.messages.create(
|
||||
thread_id=self._thread_id,
|
||||
@ -43,7 +41,7 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
|
||||
@message_handler(Reset)
|
||||
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
|
||||
# Reset the current session window.
|
||||
self._current_session_window_length = 0
|
||||
self._current_session_window_length = 1
|
||||
|
||||
@message_handler(RespondNow)
|
||||
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
|
||||
@ -61,6 +59,9 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
|
||||
# TODO: handle other statuses.
|
||||
raise ValueError(f"Run did not complete successfully: {run}")
|
||||
|
||||
# Increment the current session window length.
|
||||
self._current_session_window_length += 1
|
||||
|
||||
# Get the last message from the run.
|
||||
response = await self._client.beta.threads.messages.list(self._thread_id, run_id=run.id, order="desc", limit=1)
|
||||
last_message_content = response.data[0].content
|
||||
|
@ -1,12 +1,12 @@
|
||||
from typing import Any, List, Protocol, Sequence
|
||||
|
||||
from agnext.chat.types import Reset, RespondNow
|
||||
|
||||
from ...agent_components.type_routed_agent import TypeRoutedAgent, message_handler
|
||||
from ...core import AgentRuntime, CancellationToken
|
||||
from ..agents.base import BaseChatAgent
|
||||
from ..types import Reset, RespondNow, TextMessage
|
||||
|
||||
|
||||
class Output(Protocol):
|
||||
class GroupChatOutput(Protocol):
|
||||
def on_message_received(self, message: Any) -> None: ...
|
||||
|
||||
def get_output(self) -> Any: ...
|
||||
@ -14,7 +14,7 @@ class Output(Protocol):
|
||||
def reset(self) -> None: ...
|
||||
|
||||
|
||||
class GroupChat(BaseChatAgent):
|
||||
class GroupChat(BaseChatAgent, TypeRoutedAgent):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
@ -22,42 +22,43 @@ class GroupChat(BaseChatAgent):
|
||||
runtime: AgentRuntime,
|
||||
agents: Sequence[BaseChatAgent],
|
||||
num_rounds: int,
|
||||
output: Output,
|
||||
output: GroupChatOutput,
|
||||
) -> None:
|
||||
super().__init__(name, description, runtime)
|
||||
self._agents = agents
|
||||
self._num_rounds = num_rounds
|
||||
self._history: List[Any] = []
|
||||
self._output = output
|
||||
super().__init__(name, description, runtime)
|
||||
|
||||
@property
|
||||
def subscriptions(self) -> Sequence[type]:
|
||||
agent_sublists = [agent.subscriptions for agent in self._agents]
|
||||
return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist]
|
||||
|
||||
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None:
|
||||
if isinstance(message, Reset):
|
||||
# Reset the history.
|
||||
self._history = []
|
||||
# TODO: reset sub-agents?
|
||||
@message_handler(Reset)
|
||||
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
|
||||
self._history.clear()
|
||||
|
||||
if isinstance(message, RespondNow):
|
||||
# TODO reset...
|
||||
return self._output.get_output()
|
||||
@message_handler(RespondNow)
|
||||
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> Any:
|
||||
return self._output.get_output()
|
||||
|
||||
@message_handler(TextMessage)
|
||||
async def on_text_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
|
||||
# TODO: how should we handle the group chat receiving a message while in the middle of a conversation?
|
||||
# Should this class disallow it?
|
||||
|
||||
self._history.append(message)
|
||||
round = 0
|
||||
prev_speaker = None
|
||||
|
||||
while round < self._num_rounds:
|
||||
# TODO: add support for advanced speaker selection.
|
||||
# Select speaker (round-robin for now).
|
||||
speaker = self._agents[round % len(self._agents)]
|
||||
|
||||
# Send the last message to all agents.
|
||||
for agent in [agent for agent in self._agents]:
|
||||
# Send the last message to all agents except the previous speaker.
|
||||
for agent in [agent for agent in self._agents if agent is not prev_speaker]:
|
||||
# TODO gather and await
|
||||
_ = await self._send_message(
|
||||
self._history[-1],
|
||||
@ -66,6 +67,7 @@ class GroupChat(BaseChatAgent):
|
||||
)
|
||||
# TODO handle if response is not None
|
||||
|
||||
# Request the speaker to speak.
|
||||
response = await self._send_message(
|
||||
RespondNow(),
|
||||
speaker,
|
||||
@ -73,12 +75,13 @@ class GroupChat(BaseChatAgent):
|
||||
)
|
||||
|
||||
if response is not None:
|
||||
# 4. Append the response to the history.
|
||||
# Append the response to the history.
|
||||
self._history.append(response)
|
||||
self._output.on_message_received(response)
|
||||
|
||||
# 6. Increment the round.
|
||||
# Increment the round.
|
||||
round += 1
|
||||
prev_speaker = speaker
|
||||
|
||||
output = self._output.get_output()
|
||||
self._output.reset()
|
||||
|
@ -6,7 +6,7 @@ from ...agent_components.type_routed_agent import TypeRoutedAgent, message_handl
|
||||
from ...agent_components.types import AssistantMessage, LLMMessage, UserMessage
|
||||
from ...core import AgentRuntime, CancellationToken
|
||||
from ..agents.base import BaseChatAgent
|
||||
from ..messages import ChatMessage
|
||||
from ..types import RespondNow, TextMessage
|
||||
|
||||
|
||||
class Orchestrator(BaseChatAgent, TypeRoutedAgent):
|
||||
@ -27,26 +27,19 @@ class Orchestrator(BaseChatAgent, TypeRoutedAgent):
|
||||
self._max_turns = max_turns
|
||||
self._max_stalled_turns_before_retry = max_stalled_turns_before_retry
|
||||
self._max_retry_attempts_before_educated_guess = max_retry_attempts
|
||||
self._history: List[ChatMessage] = []
|
||||
self._history: List[TextMessage] = []
|
||||
|
||||
@message_handler(ChatMessage)
|
||||
async def on_chat_message(
|
||||
@message_handler(TextMessage)
|
||||
async def on_text_message(
|
||||
self,
|
||||
message: ChatMessage,
|
||||
message: TextMessage,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> ChatMessage | None:
|
||||
) -> TextMessage | None:
|
||||
# A task is received.
|
||||
task = message.body
|
||||
|
||||
if message.reset:
|
||||
# Reset the history.
|
||||
self._history = []
|
||||
if message.save_message_only:
|
||||
# TODO: what should we do with save_message_only messages for this pattern?
|
||||
return ChatMessage(body="OK", sender=self.name)
|
||||
task = message.content
|
||||
|
||||
# Prepare the task.
|
||||
team, names, facts, plan = await self._prepare_task(task, message.sender)
|
||||
team, names, facts, plan = await self._prepare_task(task, message.source)
|
||||
|
||||
# Main loop.
|
||||
total_turns = 0
|
||||
@ -74,11 +67,9 @@ Some additional points to consider:
|
||||
# Send the task specs to the team and signal a reset.
|
||||
for agent in self._agents:
|
||||
self._send_message(
|
||||
ChatMessage(
|
||||
body=task_specs,
|
||||
sender=self.name,
|
||||
save_message_only=True,
|
||||
reset=True,
|
||||
TextMessage(
|
||||
content=task_specs,
|
||||
source=self.name,
|
||||
),
|
||||
agent,
|
||||
)
|
||||
@ -96,18 +87,13 @@ Some additional points to consider:
|
||||
stalled_turns = 0
|
||||
while total_turns < self._max_turns:
|
||||
# Reflect on the task.
|
||||
data = await self._reflect_on_task(task, team, names, ledger, message.sender)
|
||||
data = await self._reflect_on_task(task, team, names, ledger, message.source)
|
||||
|
||||
# Check if the request is satisfied.
|
||||
if data["is_request_satisfied"]["answer"]:
|
||||
return ChatMessage(
|
||||
body="The task has been successfully addressed.",
|
||||
sender=self.name,
|
||||
payload={
|
||||
"ledgers": ledgers,
|
||||
"status": "success",
|
||||
"reason": data["is_request_satisfied"]["reason"],
|
||||
},
|
||||
return TextMessage(
|
||||
content=f"The task has been successfully addressed. {data['is_request_satisfied']['reason']}",
|
||||
source=self.name,
|
||||
)
|
||||
|
||||
# Update stalled turns.
|
||||
@ -121,7 +107,7 @@ Some additional points to consider:
|
||||
# In a retry, we need to rewrite the facts and the plan.
|
||||
|
||||
# Rewrite the facts.
|
||||
facts = await self._rewrite_facts(facts, ledger, message.sender)
|
||||
facts = await self._rewrite_facts(facts, ledger, message.source)
|
||||
|
||||
# Increment the retry attempts.
|
||||
retry_attempts += 1
|
||||
@ -129,20 +115,15 @@ Some additional points to consider:
|
||||
# Check if we should just guess.
|
||||
if retry_attempts > self._max_retry_attempts_before_educated_guess:
|
||||
# Make an educated guess.
|
||||
educated_guess = await self._educated_guess(facts, ledger, message.sender)
|
||||
educated_guess = await self._educated_guess(facts, ledger, message.source)
|
||||
if educated_guess["has_educated_guesses"]["answer"]:
|
||||
return ChatMessage(
|
||||
body="The task is addressed with an educated guess.",
|
||||
sender=self.name,
|
||||
payload={
|
||||
"ledgers": ledgers,
|
||||
"status": "educated_guess",
|
||||
"reason": educated_guess["has_educated_guesses"]["reason"],
|
||||
},
|
||||
return TextMessage(
|
||||
content=f"The task is addressed with an educated guess. {educated_guess['has_educated_guesses']['reason']}",
|
||||
source=self.name,
|
||||
)
|
||||
|
||||
# Come up with a new plan.
|
||||
plan = await self._rewrite_plan(team, ledger, message.sender)
|
||||
plan = await self._rewrite_plan(team, ledger, message.source)
|
||||
|
||||
# Exit the inner loop.
|
||||
break
|
||||
@ -152,28 +133,21 @@ Some additional points to consider:
|
||||
if subtask is None:
|
||||
subtask = ""
|
||||
|
||||
# Update agents.
|
||||
for agent in [agent for agent in self._agents]:
|
||||
_ = await self._send_message(
|
||||
TextMessage(content=subtask, source=self.name),
|
||||
agent,
|
||||
)
|
||||
|
||||
# Find the speaker.
|
||||
try:
|
||||
speaker = next(agent for agent in self._agents if agent.name == data["next_speaker"]["answer"])
|
||||
except StopIteration as e:
|
||||
raise ValueError(f"Invalid next speaker: {data['next_speaker']['answer']}") from e
|
||||
|
||||
# Update all other agents.
|
||||
for agent in [agent for agent in self._agents if agent != speaker]:
|
||||
_ = await self._send_message(
|
||||
ChatMessage(
|
||||
body=subtask,
|
||||
sender=self.name,
|
||||
save_message_only=True,
|
||||
),
|
||||
agent,
|
||||
)
|
||||
|
||||
# Update the speaker and ask to speak.
|
||||
speaker_response = await self._send_message(
|
||||
ChatMessage(body=subtask, sender=self.name),
|
||||
speaker,
|
||||
)
|
||||
# As speaker to speak.
|
||||
speaker_response = await self._send_message(RespondNow(), speaker)
|
||||
|
||||
assert speaker_response is not None
|
||||
|
||||
@ -188,10 +162,9 @@ Some additional points to consider:
|
||||
# Update all other agents with the speaker's response.
|
||||
for agent in [agent for agent in self._agents if agent != speaker]:
|
||||
_ = await self._send_message(
|
||||
ChatMessage(
|
||||
body=speaker_response.body,
|
||||
sender=speaker_response.sender,
|
||||
save_message_only=True,
|
||||
TextMessage(
|
||||
content=speaker_response.content,
|
||||
source=speaker_response.source,
|
||||
),
|
||||
agent,
|
||||
)
|
||||
@ -199,22 +172,17 @@ Some additional points to consider:
|
||||
# Update the ledger.
|
||||
ledger.append(
|
||||
UserMessage(
|
||||
content=speaker_response.body,
|
||||
source=speaker_response.sender,
|
||||
content=speaker_response.content,
|
||||
source=speaker_response.source,
|
||||
)
|
||||
)
|
||||
|
||||
# Increment the total turns.
|
||||
total_turns += 1
|
||||
|
||||
return ChatMessage(
|
||||
body="The task was not addressed",
|
||||
sender=self.name,
|
||||
payload={
|
||||
"ledgers": ledgers,
|
||||
"status": "failure",
|
||||
"reason": "The maximum number of turns was reached.",
|
||||
},
|
||||
return TextMessage(
|
||||
content="The task was not addressed. The maximum number of turns was reached.",
|
||||
source=self.name,
|
||||
)
|
||||
|
||||
async def _prepare_task(self, task: str, sender: str) -> Tuple[str, str, str, str]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user