mirror of
https://github.com/microsoft/autogen.git
synced 2025-10-03 03:57:36 +00:00
Sequential processing for group chat participant using SequentialRoutedAgent (#663)
This commit is contained in:
parent
18efc2314a
commit
0fa680577e
@ -1,13 +1,14 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from autogen_core.base import MessageContext
|
from autogen_core.base import MessageContext
|
||||||
from autogen_core.components import DefaultTopicId, RoutedAgent, event
|
from autogen_core.components import DefaultTopicId, event
|
||||||
|
|
||||||
from ...agents import BaseChatAgent, MultiModalMessage, StopMessage, TextMessage
|
from ...agents import BaseChatAgent, MultiModalMessage, StopMessage, TextMessage
|
||||||
from ._events import ContentPublishEvent, ContentRequestEvent
|
from ._events import ContentPublishEvent, ContentRequestEvent
|
||||||
|
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||||
|
|
||||||
|
|
||||||
class BaseChatAgentContainer(RoutedAgent):
|
class BaseChatAgentContainer(SequentialRoutedAgent):
|
||||||
"""A core agent class that delegates message handling to an
|
"""A core agent class that delegates message handling to an
|
||||||
:class:`autogen_agentchat.agents.BaseChatAgent` so that it can be used in a
|
:class:`autogen_agentchat.agents.BaseChatAgent` so that it can be used in a
|
||||||
group chat team.
|
group chat team.
|
||||||
|
@ -2,13 +2,14 @@ import sys
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from autogen_core.base import MessageContext, TopicId
|
from autogen_core.base import MessageContext, TopicId
|
||||||
from autogen_core.components import RoutedAgent, event
|
from autogen_core.components import event
|
||||||
|
|
||||||
from ...agents import MultiModalMessage, StopMessage, TextMessage
|
from ...agents import MultiModalMessage, StopMessage, TextMessage
|
||||||
from ._events import ContentPublishEvent, ContentRequestEvent
|
from ._events import ContentPublishEvent, ContentRequestEvent
|
||||||
|
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||||
|
|
||||||
|
|
||||||
class BaseGroupChatManager(RoutedAgent):
|
class BaseGroupChatManager(SequentialRoutedAgent):
|
||||||
"""Base class for a group chat manager that manages a group chat with multiple participants.
|
"""Base class for a group chat manager that manages a group chat with multiple participants.
|
||||||
|
|
||||||
It is the responsibility of the caller to ensure:
|
It is the responsibility of the caller to ensure:
|
||||||
|
@ -0,0 +1,51 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from autogen_core.base import MessageContext
|
||||||
|
from autogen_core.components import RoutedAgent
|
||||||
|
|
||||||
|
|
||||||
|
class FIFOLock:
|
||||||
|
"""A lock that ensures coroutines acquire the lock in the order they request it."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._queue = asyncio.Queue[asyncio.Event]()
|
||||||
|
self._locked = False
|
||||||
|
|
||||||
|
async def acquire(self) -> None:
|
||||||
|
# If the lock is not held by any coroutine, set the lock to be held
|
||||||
|
# by the current coroutine.
|
||||||
|
if not self._locked:
|
||||||
|
self._locked = True
|
||||||
|
return
|
||||||
|
|
||||||
|
# If the lock is held by another coroutine, create an event and put it
|
||||||
|
# in the queue. Wait for the event to be set.
|
||||||
|
event = asyncio.Event()
|
||||||
|
await self._queue.put(event)
|
||||||
|
await event.wait()
|
||||||
|
|
||||||
|
def release(self) -> None:
|
||||||
|
if not self._queue.empty():
|
||||||
|
# If there are events in the queue, get the next event and set it.
|
||||||
|
next_event = self._queue.get_nowait()
|
||||||
|
next_event.set()
|
||||||
|
else:
|
||||||
|
# If there are no events in the queue, release the lock.
|
||||||
|
self._locked = False
|
||||||
|
|
||||||
|
|
||||||
|
class SequentialRoutedAgent(RoutedAgent):
|
||||||
|
"""A subclass of :class:`autogen_core.components.RoutedAgent` that ensures
|
||||||
|
messages are handled sequentially in the order they arrive."""
|
||||||
|
|
||||||
|
def __init__(self, description: str) -> None:
|
||||||
|
super().__init__(description=description)
|
||||||
|
self._fifo_lock = FIFOLock()
|
||||||
|
|
||||||
|
async def on_message(self, message: Any, ctx: MessageContext) -> Any | None:
|
||||||
|
await self._fifo_lock.acquire()
|
||||||
|
try:
|
||||||
|
return await super().on_message(message, ctx)
|
||||||
|
finally:
|
||||||
|
self._fifo_lock.release()
|
@ -0,0 +1,42 @@
|
|||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from autogen_agentchat.teams.group_chat._sequential_routed_agent import SequentialRoutedAgent
|
||||||
|
from autogen_core.application import SingleThreadedAgentRuntime
|
||||||
|
from autogen_core.base import AgentId, MessageContext
|
||||||
|
from autogen_core.components import DefaultTopicId, default_subscription, message_handler
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Message:
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
@default_subscription
|
||||||
|
class TestAgent(SequentialRoutedAgent):
|
||||||
|
def __init__(self, description: str) -> None:
|
||||||
|
super().__init__(description=description)
|
||||||
|
self.messages: List[Message] = []
|
||||||
|
|
||||||
|
@message_handler
|
||||||
|
async def handle_content_publish(self, message: Message, ctx: MessageContext) -> None:
|
||||||
|
# Sleep a random amount of time to simulate processing time.
|
||||||
|
await asyncio.sleep(random.random() / 100)
|
||||||
|
self.messages.append(message)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sequential_routed_agent() -> None:
|
||||||
|
runtime = SingleThreadedAgentRuntime()
|
||||||
|
runtime.start()
|
||||||
|
await TestAgent.register(runtime, type="test_agent", factory=lambda: TestAgent(description="Test Agent"))
|
||||||
|
test_agent_id = AgentId(type="test_agent", key="default")
|
||||||
|
for i in range(100):
|
||||||
|
await runtime.publish_message(Message(content=f"{i}"), topic_id=DefaultTopicId())
|
||||||
|
await runtime.stop_when_idle()
|
||||||
|
test_agent = await runtime.try_get_underlying_agent_instance(test_agent_id, TestAgent)
|
||||||
|
for i in range(100):
|
||||||
|
assert test_agent.messages[i].content == f"{i}"
|
Loading…
x
Reference in New Issue
Block a user