mirror of
				https://github.com/microsoft/autogen.git
				synced 2025-11-04 03:39:52 +00:00 
			
		
		
		
	Rename the `ChatMessage` and `AgentEvent` base classes to `BaseChatMessage` and `BaseAgentEvent`. Bring back `ChatMessage` and `AgentEvent` as unions of the built-in concrete types to avoid breaking existing applications that depend on Pydantic serialization.

Why? Much existing code uses containers like this:

```python
class AppMessage(BaseModel):
    name: str
    message: ChatMessage

# Serialization is this:
m = AppMessage(...)
m.model_dump_json()
# Fields like HandoffMessage.target will be lost because it is now treated
# as a base class without content or target fields.
```

The assumption that `ChatMessage` or `AgentEvent` is a union of concrete types could exist in many existing code bases. So this PR brings back the union types, while keeping method type hints (such as those on `on_messages`) on the `BaseChatMessage` and `BaseAgentEvent` base classes for flexibility.
		
			
				
	
	
		
			143 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			143 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import asyncio
 | 
						|
from typing import AsyncGenerator, List, Sequence
 | 
						|
 | 
						|
import pytest
 | 
						|
import pytest_asyncio
 | 
						|
from autogen_agentchat.agents import BaseChatAgent
 | 
						|
from autogen_agentchat.base import Response
 | 
						|
from autogen_agentchat.messages import BaseChatMessage, TextMessage
 | 
						|
from autogen_agentchat.teams import RoundRobinGroupChat
 | 
						|
from autogen_core import AgentRuntime, CancellationToken, SingleThreadedAgentRuntime
 | 
						|
 | 
						|
 | 
						|
class TestAgent(BaseChatAgent):
    """A test agent whose ``on_messages`` runs a background counting loop.

    Each call to :meth:`on_messages` starts an asyncio task that increments
    :attr:`counter` about every 0.1 seconds while the agent is not paused. While
    paused, the task only polls the pause flag every 0.1 seconds.
    ``on_messages`` blocks until that task is cancelled (which :meth:`close`
    does) and then returns an empty :class:`TextMessage`. Tests can therefore
    observe pause/resume by watching whether ``counter`` keeps growing.
    """

    # The class name starts with "Test", so pytest would otherwise try to
    # collect it as a test class and emit a PytestCollectionWarning because it
    # defines __init__. Opt out of collection explicitly.
    __test__ = False

    def __init__(self, name: str, description: str) -> None:
        super().__init__(name=name, description=description)
        # Pause flag set/cleared by on_pause/on_resume and polled by the loop.
        self._is_paused = False
        # Background loop tasks started by on_messages; drained by close().
        self._tasks: List[asyncio.Task[None]] = []
        # Number of loop iterations completed while not paused.
        self.counter = 0

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        """The only message type this agent responds with is ``TextMessage``."""
        return [TextMessage]

    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        """Run the counting loop until it is cancelled, then return an empty reply.

        Args:
            messages: Ignored.
            cancellation_token: Ignored; the loop is stopped only via :meth:`close`.

        Returns:
            A ``Response`` wrapping an empty ``TextMessage`` from this agent.
        """
        assert not self._is_paused, "Agent is paused"

        async def _process() -> None:
            # Simulate a repetitive task that runs forever.
            while True:
                if self._is_paused:
                    # Idle-poll until resumed; the counter does not advance.
                    await asyncio.sleep(0.1)
                    continue
                else:
                    # Simulate an I/O operation that takes time, e.g., a browser operation.
                    await asyncio.sleep(0.1)
                    self.counter += 1

        curr_task = asyncio.create_task(_process())
        # Track the task so close() can cancel it.
        self._tasks.append(curr_task)

        try:
            # This never returns normally; it ends only when the task is
            # cancelled, at which point awaiting it raises CancelledError.
            await curr_task
        except asyncio.CancelledError:
            # Cancellation is the expected way out of the loop, so turn it into
            # a normal (empty) response instead of propagating it.
            pass

        return Response(
            chat_message=TextMessage(
                source=self.name,
                content="",
            ),
        )

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """Reset the progress counter to zero (running loops are not stopped)."""
        self.counter = 0

    async def on_pause(self, cancellation_token: CancellationToken) -> None:
        """Set the pause flag so running loops stop incrementing ``counter``."""
        self._is_paused = True

    async def on_resume(self, cancellation_token: CancellationToken) -> None:
        """Clear the pause flag so running loops resume incrementing ``counter``."""
        self._is_paused = False

    async def close(self) -> None:
        """Cancel every background loop task and wait for each to finish."""
        while self._tasks:
            task = self._tasks.pop()
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                # Expected: the task was just cancelled above.
                pass
 | 
						|
 | 
						|
 | 
						|
@pytest_asyncio.fixture(params=["single_threaded", "embedded"])  # type: ignore
 | 
						|
async def runtime(request: pytest.FixtureRequest) -> AsyncGenerator[AgentRuntime | None, None]:
 | 
						|
    if request.param == "single_threaded":
 | 
						|
        runtime = SingleThreadedAgentRuntime()
 | 
						|
        runtime.start()
 | 
						|
        yield runtime
 | 
						|
        await runtime.stop()
 | 
						|
    elif request.param == "embedded":
 | 
						|
        yield None
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.asyncio
async def test_group_chat_pause_resume(runtime: AgentRuntime | None) -> None:
    """Pausing a team must halt its agent's work, and resuming must restart it.

    Progress is measured through ``TestAgent.counter``, which the agent's
    background loop only advances while the agent is not paused.
    """
    agent = TestAgent(name="test_agent", description="test agent")
    team = RoundRobinGroupChat([agent], runtime=runtime, max_turns=1)

    async def progress_during(seconds: float) -> int:
        """Return how much ``agent.counter`` grew over ``seconds`` of waiting."""
        baseline = agent.counter
        await asyncio.sleep(seconds)
        return agent.counter - baseline

    # Run the team in the background; the agent's on_messages blocks until closed.
    team_task = asyncio.create_task(team.run())

    # While running normally, the agent must make progress.
    assert await progress_during(1) > 0

    # Pause the team and give the agent time to observe the pause;
    # after that, no further progress may be made.
    await team.pause()
    await asyncio.sleep(1)
    assert await progress_during(1) == 0

    # Resume the team and give the agent time to pick back up;
    # progress must then continue.
    await team.resume()
    await asyncio.sleep(1)
    assert await progress_during(1) > 0

    # Clean up: closing the agent cancels its background loop, so on_messages
    # returns a response and the single-turn team run can finish.
    await agent.close()
    await team_task
 |