diff --git a/python/packages/autogen-core/src/autogen_core/_agent_proxy.py b/python/packages/autogen-core/src/autogen_core/_agent_proxy.py index 09cf3c1de..e23022fb2 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent_proxy.py +++ b/python/packages/autogen-core/src/autogen_core/_agent_proxy.py @@ -33,12 +33,14 @@ class AgentProxy: *, sender: AgentId, cancellation_token: CancellationToken | None = None, + message_id: str | None = None, ) -> Any: return await self._runtime.send_message( message, recipient=self._agent, sender=sender, cancellation_token=cancellation_token, + message_id=message_id, ) async def save_state(self) -> Mapping[str, Any]: diff --git a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py index 5a3ebefbc..8f1b3ae7d 100644 --- a/python/packages/autogen-core/src/autogen_core/_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_agent_runtime.py @@ -26,6 +26,7 @@ class AgentRuntime(Protocol): *, sender: AgentId | None = None, cancellation_token: CancellationToken | None = None, + message_id: str | None = None, ) -> Any: """Send a message to an agent and get a response. diff --git a/python/packages/autogen-core/src/autogen_core/_base_agent.py b/python/packages/autogen-core/src/autogen_core/_base_agent.py index 79bffd36d..cfefb4ab7 100644 --- a/python/packages/autogen-core/src/autogen_core/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_base_agent.py @@ -121,6 +121,7 @@ class BaseAgent(ABC, Agent): recipient: AgentId, *, cancellation_token: CancellationToken | None = None, + message_id: str | None = None, ) -> Any: """See :py:meth:`autogen_core.AgentRuntime.send_message` for more information.""" if cancellation_token is None: @@ -131,6 +132,7 @@ class BaseAgent(ABC, Agent): sender=self.id, recipient=recipient, cancellation_token=cancellation_token, + message_id=message_id, ) async def publish_message( diff --git a/python/packages/autogen-core/src/autogen_core/_closure_agent.py b/python/packages/autogen-core/src/autogen_core/_closure_agent.py index 8f93b4f2b..5e172ee73 100644 --- a/python/packages/autogen-core/src/autogen_core/_closure_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_closure_agent.py @@ -61,6 +61,7 @@ class ClosureContext(Protocol): recipient: AgentId, *, cancellation_token: CancellationToken | None = None, + message_id: str | None = None, ) -> Any: ... async def publish_message( diff --git a/python/packages/autogen-core/src/autogen_core/_message_handler_context.py b/python/packages/autogen-core/src/autogen_core/_message_handler_context.py index b0f08ac8c..9e5a6a97d 100644 --- a/python/packages/autogen-core/src/autogen_core/_message_handler_context.py +++ b/python/packages/autogen-core/src/autogen_core/_message_handler_context.py @@ -8,7 +8,7 @@ from ._agent_id import AgentId class MessageHandlerContext: def __init__(self) -> None: raise RuntimeError( - "MessageHandlerContext cannot be instantiated. It is a static class that provides context management for agent instantiation." + "MessageHandlerContext cannot be instantiated. It is a static class that provides context management for message handling." ) _MESSAGE_HANDLER_CONTEXT: ClassVar[ContextVar[AgentId]] = ContextVar("_MESSAGE_HANDLER_CONTEXT") diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index 37c84ca84..9c292b9f2 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -4,7 +4,6 @@ import asyncio import inspect import logging import sys -import threading import uuid import warnings from asyncio import CancelledError, Future, Queue, Task @@ -14,6 +13,15 @@ from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set from opentelemetry.trace import TracerProvider +from .logging import ( + AgentConstructionExceptionEvent, + DeliveryStage, + MessageDroppedEvent, + MessageEvent, + MessageHandlerExceptionEvent, + MessageKind, +) + if sys.version_info >= (3, 13): from asyncio import Queue, QueueShutDown else: @@ -32,7 +40,7 @@ from ._intervention import DropMessage, InterventionHandler from ._message_context import MessageContext from ._message_handler_context import MessageHandlerContext from ._runtime_impl_helpers import SubscriptionManager, get_impl -from ._serialization import MessageSerializer, SerializationRegistry +from ._serialization import JSON_DATA_CONTENT_TYPE, MessageSerializer, SerializationRegistry from ._subscription import Subscription from ._telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata from ._topic import TopicId @@ -70,6 +78,7 @@ class SendMessageEnvelope: future: Future[Any] cancellation_token: CancellationToken metadata: EnvelopeMetadata | None = None + message_id: str @dataclass(kw_only=True) @@ -87,25 +96,6 @@ P = ParamSpec("P") T = TypeVar("T", bound=Agent) -class Counter: - def __init__(self) -> None: - self._count: int = 0 - self.threadLock = threading.Lock() - - def increment(self) -> None: - self.threadLock.acquire() - self._count += 1 - self.threadLock.release() - - def get(self) -> int: - return self._count - - def decrement(self) -> None: - self.threadLock.acquire() - self._count -= 1 - self.threadLock.release() - - class RunContext: def __init__(self, runtime: SingleThreadedAgentRuntime) -> None: self._runtime = runtime @@ -194,19 +184,23 @@ class SingleThreadedAgentRuntime(AgentRuntime): *, sender: AgentId | None = None, cancellation_token: CancellationToken | None = None, + message_id: str | None = None, ) -> Any: if cancellation_token is None: cancellation_token = CancellationToken() - # event_logger.info( - # MessageEvent( - # payload=message, - # sender=sender, - # receiver=recipient, - # kind=MessageKind.DIRECT, - # delivery_stage=DeliveryStage.SEND, - # ) - # ) + if message_id is None: + message_id = str(uuid.uuid4()) + + event_logger.info( + MessageEvent( + payload=self._try_serialize(message), + sender=sender, + receiver=recipient, + kind=MessageKind.DIRECT, + delivery_stage=DeliveryStage.SEND, + ) + ) with self._tracer_helper.trace_block( "create", @@ -229,6 +223,7 @@ class SingleThreadedAgentRuntime(AgentRuntime): cancellation_token=cancellation_token, sender=sender, metadata=get_telemetry_envelope_metadata(), + message_id=message_id, ) ) @@ -259,15 +254,15 @@ class SingleThreadedAgentRuntime(AgentRuntime): if message_id is None: message_id = str(uuid.uuid4()) - # event_logger.info( - # MessageEvent( - # payload=message, - # sender=sender, - # receiver=None, - # kind=MessageKind.PUBLISH, - # delivery_stage=DeliveryStage.SEND, - # ) - # ) + event_logger.info( + MessageEvent( + payload=self._try_serialize(message), + sender=sender, + receiver=topic_id, + kind=MessageKind.PUBLISH, + delivery_stage=DeliveryStage.SEND, + ) + ) await self._message_queue.put( PublishMessageEnvelope( @@ -295,32 +290,31 @@ class SingleThreadedAgentRuntime(AgentRuntime): async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata): recipient = message_envelope.recipient - # todo: check if recipient is in the known namespaces - # assert recipient in self._agents + + if recipient.type not in self._known_agent_names: + raise LookupError(f"Agent type '{recipient.type}' does not exist.") try: - # TODO use id - sender_name = message_envelope.sender.type if message_envelope.sender is not None else "Unknown" + sender_id = str(message_envelope.sender) if message_envelope.sender is not None else "Unknown" logger.info( - f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_name}" + f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_id}" + ) + event_logger.info( + MessageEvent( + payload=self._try_serialize(message_envelope.message), + sender=message_envelope.sender, + receiver=recipient, + kind=MessageKind.DIRECT, + delivery_stage=DeliveryStage.DELIVER, + ) ) - # event_logger.info( - # MessageEvent( - # payload=message_envelope.message, - # sender=message_envelope.sender, - # receiver=recipient, - # kind=MessageKind.DIRECT, - # delivery_stage=DeliveryStage.DELIVER, - # ) - # ) recipient_agent = await self._get_agent(recipient) message_context = MessageContext( sender=message_envelope.sender, topic_id=None, is_rpc=True, cancellation_token=message_envelope.cancellation_token, - # Will be fixed when send API removed - message_id="NOT_DEFINED_TODO_FIX", + message_id=message_envelope.message_id, ) with MessageHandlerContext.populate_context(recipient_agent.id): response = await recipient_agent.on_message( @@ -331,12 +325,36 @@ class SingleThreadedAgentRuntime(AgentRuntime): if not message_envelope.future.cancelled(): message_envelope.future.set_exception(e) self._message_queue.task_done() + event_logger.info( + MessageHandlerExceptionEvent( + payload=self._try_serialize(message_envelope.message), + handling_agent=recipient, + exception=e, + ) + ) return except BaseException as e: message_envelope.future.set_exception(e) self._message_queue.task_done() + event_logger.info( + MessageHandlerExceptionEvent( + payload=self._try_serialize(message_envelope.message), + handling_agent=recipient, + exception=e, + ) + ) return + event_logger.info( + MessageEvent( + payload=self._try_serialize(response), + sender=message_envelope.recipient, + receiver=message_envelope.sender, + kind=MessageKind.RESPOND, + delivery_stage=DeliveryStage.SEND, + ) + ) + await self._message_queue.put( ResponseMessageEnvelope( message=response, @@ -365,15 +383,15 @@ class SingleThreadedAgentRuntime(AgentRuntime): logger.info( f"Calling message handler for {agent_id.type} with message type {type(message_envelope.message).__name__} published by {sender_name}" ) - # event_logger.info( - # MessageEvent( - # payload=message_envelope.message, - # sender=message_envelope.sender, - # receiver=agent, - # kind=MessageKind.PUBLISH, - # delivery_stage=DeliveryStage.DELIVER, - # ) - # ) + event_logger.info( + MessageEvent( + payload=self._try_serialize(message_envelope.message), + sender=message_envelope.sender, + receiver=None, + kind=MessageKind.PUBLISH, + delivery_stage=DeliveryStage.DELIVER, + ) + ) message_context = MessageContext( sender=message_envelope.sender, topic_id=message_envelope.topic_id, @@ -386,20 +404,29 @@ class SingleThreadedAgentRuntime(AgentRuntime): async def _on_message(agent: Agent, message_context: MessageContext) -> Any: with self._tracer_helper.trace_block("process", agent.id, parent=None): with MessageHandlerContext.populate_context(agent.id): - return await agent.on_message( - message_envelope.message, - ctx=message_context, - ) + try: + return await agent.on_message( + message_envelope.message, + ctx=message_context, + ) + except BaseException as e: + logger.error(f"Error processing publish message for {agent.id}", exc_info=True) + event_logger.info( + MessageHandlerExceptionEvent( + payload=self._try_serialize(message_envelope.message), + handling_agent=agent.id, + exception=e, + ) + ) + raise future = _on_message(agent, message_context) responses.append(future) await asyncio.gather(*responses) - except BaseException as e: - # Ignore cancelled errors from logs - if isinstance(e, CancelledError): - return - logger.error("Error processing publish message", exc_info=True) + except BaseException: + # Ignore exceptions raised during publishing. We've already logged them above. + pass finally: self._message_queue.task_done() # TODO if responses are given for a publish @@ -414,18 +441,18 @@ class SingleThreadedAgentRuntime(AgentRuntime): logger.info( f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.type}: {content}" ) - # event_logger.info( - # MessageEvent( - # payload=message_envelope.message, - # sender=message_envelope.sender, - # receiver=message_envelope.recipient, - # kind=MessageKind.RESPOND, - # delivery_stage=DeliveryStage.DELIVER, - # ) - # ) - self._message_queue.task_done() + event_logger.info( + MessageEvent( + payload=self._try_serialize(message_envelope.message), + sender=message_envelope.sender, + receiver=message_envelope.recipient, + kind=MessageKind.RESPOND, + delivery_stage=DeliveryStage.DELIVER, + ) + ) if not message_envelope.future.cancelled(): message_envelope.future.set_result(message_envelope.message) + self._message_queue.task_done() @deprecated("Manually stepping the runtime processing is deprecated. Use start() instead.") async def process_next(self) -> None: @@ -453,6 +480,14 @@ class SingleThreadedAgentRuntime(AgentRuntime): future.set_exception(e) return if temp_message is DropMessage or isinstance(temp_message, DropMessage): + event_logger.info( + MessageDroppedEvent( + payload=self._try_serialize(message), + sender=sender, + receiver=recipient, + kind=MessageKind.DIRECT, + ) + ) future.set_exception(MessageDroppedException()) return @@ -463,6 +498,7 @@ class SingleThreadedAgentRuntime(AgentRuntime): case PublishMessageEnvelope( message=message, sender=sender, + topic_id=topic_id, ): if self._intervention_handlers is not None: for handler in self._intervention_handlers: @@ -477,7 +513,14 @@ class SingleThreadedAgentRuntime(AgentRuntime): logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True) return if temp_message is DropMessage or isinstance(temp_message, DropMessage): - # TODO log message dropped + event_logger.info( + MessageDroppedEvent( + payload=self._try_serialize(message), + sender=sender, + receiver=topic_id, + kind=MessageKind.PUBLISH, + ) + ) return message_envelope.message = temp_message @@ -495,6 +538,14 @@ class SingleThreadedAgentRuntime(AgentRuntime): future.set_exception(e) return if temp_message is DropMessage or isinstance(temp_message, DropMessage): + event_logger.info( + MessageDroppedEvent( + payload=self._try_serialize(message), + sender=sender, + receiver=recipient, + kind=MessageKind.RESPOND, + ) + ) future.set_exception(MessageDroppedException()) return message_envelope.message = temp_message @@ -603,23 +654,34 @@ class SingleThreadedAgentRuntime(AgentRuntime): agent_id: AgentId, ) -> T: with AgentInstantiationContext.populate_context((self, agent_id)): - if len(inspect.signature(agent_factory).parameters) == 0: - factory_one = cast(Callable[[], T], agent_factory) - agent = factory_one() - elif len(inspect.signature(agent_factory).parameters) == 2: - warnings.warn( - "Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.", - stacklevel=2, + try: + if len(inspect.signature(agent_factory).parameters) == 0: + factory_one = cast(Callable[[], T], agent_factory) + agent = factory_one() + elif len(inspect.signature(agent_factory).parameters) == 2: + warnings.warn( + "Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.", + stacklevel=2, + ) + factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory) + agent = factory_two(self, agent_id) + else: + raise ValueError("Agent factory must take 0 or 2 arguments.") + + if inspect.isawaitable(agent): + return cast(T, await agent) + + return agent + + except BaseException as e: + event_logger.info( + AgentConstructionExceptionEvent( + agent_id=agent_id, + exception=e, + ) ) - factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory) - agent = factory_two(self, agent_id) - else: - raise ValueError("Agent factory must take 0 or 2 arguments.") - - if inspect.isawaitable(agent): - return cast(T, await agent) - - return agent + logger.error(f"Error constructing agent {agent_id}", exc_info=True) + raise async def _get_agent(self, agent_id: AgentId) -> Agent: if agent_id in self._instantiated_agents: @@ -666,3 +728,12 @@ class SingleThreadedAgentRuntime(AgentRuntime): def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None: self._serialization_registry.add_serializer(serializer) + + def _try_serialize(self, message: Any) -> str: + try: + type_name = self._serialization_registry.type_name(message) + return self._serialization_registry.serialize( + message, type_name=type_name, data_content_type=JSON_DATA_CONTENT_TYPE + ).decode("utf-8") + except ValueError: + return "Message could not be serialized" diff --git a/python/packages/autogen-core/src/autogen_core/logging.py b/python/packages/autogen-core/src/autogen_core/logging.py index 5e3870203..11bb46c04 100644 --- a/python/packages/autogen-core/src/autogen_core/logging.py +++ b/python/packages/autogen-core/src/autogen_core/logging.py @@ -2,7 +2,8 @@ import json from enum import Enum from typing import Any, cast -from autogen_core import AgentId +from ._agent_id import AgentId +from ._topic import TopicId class LLMCallEvent: @@ -57,9 +58,9 @@ class MessageEvent: def __init__( self, *, - payload: Any, + payload: str, sender: AgentId | None, - receiver: AgentId | None, + receiver: AgentId | TopicId | None, kind: MessageKind, delivery_stage: DeliveryStage, **kwargs: Any, @@ -68,18 +69,70 @@ class MessageEvent: self.kwargs["payload"] = payload self.kwargs["sender"] = None if sender is None else str(sender) self.kwargs["receiver"] = None if receiver is None else str(receiver) - self.kwargs["kind"] = kind - self.kwargs["delivery_stage"] = delivery_stage + self.kwargs["kind"] = str(kind) + self.kwargs["delivery_stage"] = str(delivery_stage) self.kwargs["type"] = "Message" - @property - def prompt_tokens(self) -> int: - return cast(int, self.kwargs["prompt_tokens"]) - - @property - def completion_tokens(self) -> int: - return cast(int, self.kwargs["completion_tokens"]) - + # This must output the event in a json serializable format + def __str__(self) -> str: + return json.dumps(self.kwargs) + + +class MessageDroppedEvent: + def __init__( + self, + *, + payload: str, + sender: AgentId | None, + receiver: AgentId | TopicId | None, + kind: MessageKind, + **kwargs: Any, + ) -> None: + self.kwargs = kwargs + self.kwargs["payload"] = payload + self.kwargs["sender"] = None if sender is None else str(sender) + self.kwargs["receiver"] = None if receiver is None else str(receiver) + self.kwargs["kind"] = str(kind) + self.kwargs["type"] = "MessageDropped" + + # This must output the event in a json serializable format + def __str__(self) -> str: + return json.dumps(self.kwargs) + + +class MessageHandlerExceptionEvent: + def __init__( + self, + *, + payload: str, + handling_agent: AgentId, + exception: BaseException, + **kwargs: Any, + ) -> None: + self.kwargs = kwargs + self.kwargs["payload"] = payload + self.kwargs["handling_agent"] = str(handling_agent) + self.kwargs["exception"] = str(exception) + self.kwargs["type"] = "MessageHandlerException" + + # This must output the event in a json serializable format + def __str__(self) -> str: + return json.dumps(self.kwargs) + + +class AgentConstructionExceptionEvent: + def __init__( + self, + *, + agent_id: AgentId, + exception: BaseException, + **kwargs: Any, + ) -> None: + self.kwargs = kwargs + self.kwargs["agent_id"] = str(agent_id) + self.kwargs["exception"] = str(exception) + self.kwargs["type"] = "AgentConstructionException" + # This must output the event in a json serializable format def __str__(self) -> str: return json.dumps(self.kwargs) diff --git a/python/packages/autogen-core/tests/test_routed_agent.py b/python/packages/autogen-core/tests/test_routed_agent.py index 440c839fa..15d7d81cc 100644 --- a/python/packages/autogen-core/tests/test_routed_agent.py +++ b/python/packages/autogen-core/tests/test_routed_agent.py @@ -78,7 +78,7 @@ async def test_message_handler_router() -> None: @dataclass -class TestMessage: +class MyMessage: value: str @@ -89,15 +89,15 @@ class RoutedAgentMessageCustomMatch(RoutedAgent): self.handler_two_called = False @staticmethod - def match_one(message: TestMessage, ctx: MessageContext) -> bool: + def match_one(message: MyMessage, ctx: MessageContext) -> bool: return message.value == "one" @message_handler(match=match_one) - async def handler_one(self, message: TestMessage, ctx: MessageContext) -> None: + async def handler_one(self, message: MyMessage, ctx: MessageContext) -> None: self.handler_one_called = True - @message_handler(match=cast(Callable[[TestMessage, MessageContext], bool], lambda msg, ctx: msg.value == "two")) # type: ignore - async def handler_two(self, message: TestMessage, ctx: MessageContext) -> None: + @message_handler(match=cast(Callable[[MyMessage, MessageContext], bool], lambda msg, ctx: msg.value == "two")) # type: ignore + async def handler_two(self, message: MyMessage, ctx: MessageContext) -> None: self.handler_two_called = True @@ -113,14 +113,14 @@ async def test_routed_agent_message_matching() -> None: assert agent.handler_two_called is False runtime.start() - await runtime.send_message(TestMessage("one"), recipient=agent_id) + await runtime.send_message(MyMessage("one"), recipient=agent_id) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch) assert agent.handler_one_called is True assert agent.handler_two_called is False runtime.start() - await runtime.send_message(TestMessage("two"), recipient=agent_id) + await runtime.send_message(MyMessage("two"), recipient=agent_id) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch) assert agent.handler_one_called is True @@ -133,11 +133,11 @@ class EventAgent(RoutedAgent): self.num_calls = [0, 0] @event(match=lambda msg, ctx: msg.value == "one") # type: ignore - async def on_event_one(self, message: TestMessage, ctx: MessageContext) -> None: + async def on_event_one(self, message: MyMessage, ctx: MessageContext) -> None: self.num_calls[0] += 1 @event(match=lambda msg, ctx: msg.value == "two") # type: ignore - async def on_event_two(self, message: TestMessage, ctx: MessageContext) -> None: + async def on_event_two(self, message: MyMessage, ctx: MessageContext) -> None: self.num_calls[1] += 1 @@ -150,7 +150,7 @@ async def test_event() -> None: # Send a broadcast message. runtime.start() - await runtime.publish_message(TestMessage("one"), topic_id=TopicId("default", "default")) + await runtime.publish_message(MyMessage("one"), topic_id=TopicId("default", "default")) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent) assert agent.num_calls[0] == 1 @@ -158,7 +158,7 @@ async def test_event() -> None: # Send another broadcast message. runtime.start() - await runtime.publish_message(TestMessage("two"), topic_id=TopicId("default", "default")) + await runtime.publish_message(MyMessage("two"), topic_id=TopicId("default", "default")) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent) assert agent.num_calls[0] == 1 @@ -166,7 +166,7 @@ async def test_event() -> None: # Send an RPC message, expect no change. runtime.start() - await runtime.send_message(TestMessage("one"), recipient=agent_id) + await runtime.send_message(MyMessage("one"), recipient=agent_id) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent) assert agent.num_calls[0] == 1 @@ -179,12 +179,12 @@ class RPCAgent(RoutedAgent): self.num_calls = [0, 0] @rpc(match=lambda msg, ctx: msg.value == "one") # type: ignore - async def on_rpc_one(self, message: TestMessage, ctx: MessageContext) -> TestMessage: + async def on_rpc_one(self, message: MyMessage, ctx: MessageContext) -> MyMessage: self.num_calls[0] += 1 return message @rpc(match=lambda msg, ctx: msg.value == "two") # type: ignore - async def on_rpc_two(self, message: TestMessage, ctx: MessageContext) -> TestMessage: + async def on_rpc_two(self, message: MyMessage, ctx: MessageContext) -> MyMessage: self.num_calls[1] += 1 return message @@ -198,7 +198,7 @@ async def test_rpc() -> None: # Send an RPC message. runtime.start() - await runtime.send_message(TestMessage("one"), recipient=agent_id) + await runtime.send_message(MyMessage("one"), recipient=agent_id) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent) assert agent.num_calls[0] == 1 @@ -206,7 +206,7 @@ async def test_rpc() -> None: # Send another RPC message. runtime.start() - await runtime.send_message(TestMessage("two"), recipient=agent_id) + await runtime.send_message(MyMessage("two"), recipient=agent_id) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent) assert agent.num_calls[0] == 1 @@ -214,7 +214,7 @@ async def test_rpc() -> None: # Send a broadcast message, expect no change. runtime.start() - await runtime.publish_message(TestMessage("one"), topic_id=TopicId("default", "default")) + await runtime.publish_message(MyMessage("one"), topic_id=TopicId("default", "default")) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent) assert agent.num_calls[0] == 1 diff --git a/python/packages/autogen-core/tests/test_runtime.py b/python/packages/autogen-core/tests/test_runtime.py index e5b04bd87..16de5ccc1 100644 --- a/python/packages/autogen-core/tests/test_runtime.py +++ b/python/packages/autogen-core/tests/test_runtime.py @@ -20,10 +20,10 @@ from autogen_test_utils import ( MessageType, NoopAgent, ) -from autogen_test_utils.telemetry_test_utils import TestExporter, get_test_tracer_provider +from autogen_test_utils.telemetry_test_utils import MyTestExporter, get_test_tracer_provider from opentelemetry.sdk.trace import TracerProvider -test_exporter = TestExporter() +test_exporter = MyTestExporter() @pytest.fixture @@ -88,7 +88,7 @@ async def test_register_receives_publish(tracer_provider: TracerProvider) -> Non @pytest.mark.asyncio -async def test_register_receives_publish_with_exception(caplog: pytest.LogCaptureFixture) -> None: +async def test_register_receives_publish_with_construction(caplog: pytest.LogCaptureFixture) -> None: runtime = SingleThreadedAgentRuntime() runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType)) @@ -103,8 +103,9 @@ async def test_register_receives_publish_with_exception(caplog: pytest.LogCaptur runtime.start() await runtime.publish_message(MessageType(), topic_id=TopicId("default", "default")) await runtime.stop_when_idle() - # Check if logger has the exception. - assert any("Error processing publish message" in e.message for e in caplog.records) + + # Check if logger has the exception. + assert any("Error constructing agent" in e.message for e in caplog.records) @pytest.mark.asyncio diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py index d331be766..b39ec04a3 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py @@ -339,7 +339,9 @@ class GrpcWorkerAgentRuntime(AgentRuntime): *, sender: AgentId | None = None, cancellation_token: CancellationToken | None = None, + message_id: str | None = None, ) -> Any: + # TODO: use message_id if not self._running: raise ValueError("Runtime must be running when sending message.") if self._host_connection is None: diff --git a/python/packages/autogen-test-utils/pyproject.toml b/python/packages/autogen-test-utils/pyproject.toml index 22b0ba073..76877e5a9 100644 --- a/python/packages/autogen-test-utils/pyproject.toml +++ b/python/packages/autogen-test-utils/pyproject.toml @@ -7,8 +7,10 @@ name = "autogen-test-utils" version = "0.0.0" license = {file = "LICENSE-CODE"} requires-python = ">=3.10" -dependencies = ["autogen-core", - +dependencies = [ + "autogen-core", + "pytest", + "opentelemetry-sdk>=1.27.0", ] [tool.ruff] diff --git a/python/packages/autogen-test-utils/src/autogen_test_utils/telemetry_test_utils.py b/python/packages/autogen-test-utils/src/autogen_test_utils/telemetry_test_utils.py index 994e21042..00a13a6ea 100644 --- a/python/packages/autogen-test-utils/src/autogen_test_utils/telemetry_test_utils.py +++ b/python/packages/autogen-test-utils/src/autogen_test_utils/telemetry_test_utils.py @@ -1,10 +1,11 @@ from typing import List, Sequence +import pytest from opentelemetry.sdk.trace import ReadableSpan, TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult -class TestExporter(SpanExporter): +class MyTestExporter(SpanExporter): def __init__(self) -> None: self.exported_spans: List[ReadableSpan] = [] @@ -24,7 +25,7 @@ class TestExporter(SpanExporter): return self.exported_spans -def get_test_tracer_provider(exporter: TestExporter) -> TracerProvider: +def get_test_tracer_provider(exporter: MyTestExporter) -> TracerProvider: tracer_provider = TracerProvider() tracer_provider.add_span_processor(SimpleSpanProcessor(exporter)) return tracer_provider diff --git a/python/uv.lock b/python/uv.lock index f9b7b66d7..2a2a012ef 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -627,10 +627,16 @@ version = "0.0.0" source = { editable = "packages/autogen-test-utils" } dependencies = [ { name = "autogen-core" }, + { name = "opentelemetry-sdk" }, + { name = "pytest" }, ] [package.metadata] -requires-dist = [{ name = "autogen-core", editable = "packages/autogen-core" }] +requires-dist = [ + { name = "autogen-core", editable = "packages/autogen-core" }, + { name = "opentelemetry-sdk", specifier = ">=1.27.0" }, + { name = "pytest" }, +] [[package]] name = "autogenstudio"