Update logged events, add message id to send message (#4868)

* Update logged events

* add message_id

* serialize payload for log

* fix pytest warning

* serialization

* fix test

* lock

* fix warning and test
This commit is contained in:
Jack Gerrits 2024-12-31 15:11:48 -05:00 committed by GitHub
parent d2a74de3ad
commit e6ac2f37fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 283 additions and 141 deletions

View File

@ -33,12 +33,14 @@ class AgentProxy:
*, *,
sender: AgentId, sender: AgentId,
cancellation_token: CancellationToken | None = None, cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> Any: ) -> Any:
return await self._runtime.send_message( return await self._runtime.send_message(
message, message,
recipient=self._agent, recipient=self._agent,
sender=sender, sender=sender,
cancellation_token=cancellation_token, cancellation_token=cancellation_token,
message_id=message_id,
) )
async def save_state(self) -> Mapping[str, Any]: async def save_state(self) -> Mapping[str, Any]:

View File

@ -26,6 +26,7 @@ class AgentRuntime(Protocol):
*, *,
sender: AgentId | None = None, sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None, cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> Any: ) -> Any:
"""Send a message to an agent and get a response. """Send a message to an agent and get a response.

View File

@ -121,6 +121,7 @@ class BaseAgent(ABC, Agent):
recipient: AgentId, recipient: AgentId,
*, *,
cancellation_token: CancellationToken | None = None, cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> Any: ) -> Any:
"""See :py:meth:`autogen_core.AgentRuntime.send_message` for more information.""" """See :py:meth:`autogen_core.AgentRuntime.send_message` for more information."""
if cancellation_token is None: if cancellation_token is None:
@ -131,6 +132,7 @@ class BaseAgent(ABC, Agent):
sender=self.id, sender=self.id,
recipient=recipient, recipient=recipient,
cancellation_token=cancellation_token, cancellation_token=cancellation_token,
message_id=message_id,
) )
async def publish_message( async def publish_message(

View File

@ -61,6 +61,7 @@ class ClosureContext(Protocol):
recipient: AgentId, recipient: AgentId,
*, *,
cancellation_token: CancellationToken | None = None, cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> Any: ... ) -> Any: ...
async def publish_message( async def publish_message(

View File

@ -8,7 +8,7 @@ from ._agent_id import AgentId
class MessageHandlerContext: class MessageHandlerContext:
def __init__(self) -> None: def __init__(self) -> None:
raise RuntimeError( raise RuntimeError(
"MessageHandlerContext cannot be instantiated. It is a static class that provides context management for agent instantiation." "MessageHandlerContext cannot be instantiated. It is a static class that provides context management for message handling."
) )
_MESSAGE_HANDLER_CONTEXT: ClassVar[ContextVar[AgentId]] = ContextVar("_MESSAGE_HANDLER_CONTEXT") _MESSAGE_HANDLER_CONTEXT: ClassVar[ContextVar[AgentId]] = ContextVar("_MESSAGE_HANDLER_CONTEXT")

View File

@ -4,7 +4,6 @@ import asyncio
import inspect import inspect
import logging import logging
import sys import sys
import threading
import uuid import uuid
import warnings import warnings
from asyncio import CancelledError, Future, Queue, Task from asyncio import CancelledError, Future, Queue, Task
@ -14,6 +13,15 @@ from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set
from opentelemetry.trace import TracerProvider from opentelemetry.trace import TracerProvider
from .logging import (
AgentConstructionExceptionEvent,
DeliveryStage,
MessageDroppedEvent,
MessageEvent,
MessageHandlerExceptionEvent,
MessageKind,
)
if sys.version_info >= (3, 13): if sys.version_info >= (3, 13):
from asyncio import Queue, QueueShutDown from asyncio import Queue, QueueShutDown
else: else:
@ -32,7 +40,7 @@ from ._intervention import DropMessage, InterventionHandler
from ._message_context import MessageContext from ._message_context import MessageContext
from ._message_handler_context import MessageHandlerContext from ._message_handler_context import MessageHandlerContext
from ._runtime_impl_helpers import SubscriptionManager, get_impl from ._runtime_impl_helpers import SubscriptionManager, get_impl
from ._serialization import MessageSerializer, SerializationRegistry from ._serialization import JSON_DATA_CONTENT_TYPE, MessageSerializer, SerializationRegistry
from ._subscription import Subscription from ._subscription import Subscription
from ._telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata from ._telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata
from ._topic import TopicId from ._topic import TopicId
@ -70,6 +78,7 @@ class SendMessageEnvelope:
future: Future[Any] future: Future[Any]
cancellation_token: CancellationToken cancellation_token: CancellationToken
metadata: EnvelopeMetadata | None = None metadata: EnvelopeMetadata | None = None
message_id: str
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -87,25 +96,6 @@ P = ParamSpec("P")
T = TypeVar("T", bound=Agent) T = TypeVar("T", bound=Agent)
class Counter:
def __init__(self) -> None:
self._count: int = 0
self.threadLock = threading.Lock()
def increment(self) -> None:
self.threadLock.acquire()
self._count += 1
self.threadLock.release()
def get(self) -> int:
return self._count
def decrement(self) -> None:
self.threadLock.acquire()
self._count -= 1
self.threadLock.release()
class RunContext: class RunContext:
def __init__(self, runtime: SingleThreadedAgentRuntime) -> None: def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
self._runtime = runtime self._runtime = runtime
@ -194,19 +184,23 @@ class SingleThreadedAgentRuntime(AgentRuntime):
*, *,
sender: AgentId | None = None, sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None, cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> Any: ) -> Any:
if cancellation_token is None: if cancellation_token is None:
cancellation_token = CancellationToken() cancellation_token = CancellationToken()
# event_logger.info( if message_id is None:
# MessageEvent( message_id = str(uuid.uuid4())
# payload=message,
# sender=sender, event_logger.info(
# receiver=recipient, MessageEvent(
# kind=MessageKind.DIRECT, payload=self._try_serialize(message),
# delivery_stage=DeliveryStage.SEND, sender=sender,
# ) receiver=recipient,
# ) kind=MessageKind.DIRECT,
delivery_stage=DeliveryStage.SEND,
)
)
with self._tracer_helper.trace_block( with self._tracer_helper.trace_block(
"create", "create",
@ -229,6 +223,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
cancellation_token=cancellation_token, cancellation_token=cancellation_token,
sender=sender, sender=sender,
metadata=get_telemetry_envelope_metadata(), metadata=get_telemetry_envelope_metadata(),
message_id=message_id,
) )
) )
@ -259,15 +254,15 @@ class SingleThreadedAgentRuntime(AgentRuntime):
if message_id is None: if message_id is None:
message_id = str(uuid.uuid4()) message_id = str(uuid.uuid4())
# event_logger.info( event_logger.info(
# MessageEvent( MessageEvent(
# payload=message, payload=self._try_serialize(message),
# sender=sender, sender=sender,
# receiver=None, receiver=topic_id,
# kind=MessageKind.PUBLISH, kind=MessageKind.PUBLISH,
# delivery_stage=DeliveryStage.SEND, delivery_stage=DeliveryStage.SEND,
# ) )
# ) )
await self._message_queue.put( await self._message_queue.put(
PublishMessageEnvelope( PublishMessageEnvelope(
@ -295,32 +290,31 @@ class SingleThreadedAgentRuntime(AgentRuntime):
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata): with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata):
recipient = message_envelope.recipient recipient = message_envelope.recipient
# todo: check if recipient is in the known namespaces
# assert recipient in self._agents if recipient.type not in self._known_agent_names:
raise LookupError(f"Agent type '{recipient.type}' does not exist.")
try: try:
# TODO use id sender_id = str(message_envelope.sender) if message_envelope.sender is not None else "Unknown"
sender_name = message_envelope.sender.type if message_envelope.sender is not None else "Unknown"
logger.info( logger.info(
f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_name}" f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_id}"
)
event_logger.info(
MessageEvent(
payload=self._try_serialize(message_envelope.message),
sender=message_envelope.sender,
receiver=recipient,
kind=MessageKind.DIRECT,
delivery_stage=DeliveryStage.DELIVER,
)
) )
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=recipient,
# kind=MessageKind.DIRECT,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
recipient_agent = await self._get_agent(recipient) recipient_agent = await self._get_agent(recipient)
message_context = MessageContext( message_context = MessageContext(
sender=message_envelope.sender, sender=message_envelope.sender,
topic_id=None, topic_id=None,
is_rpc=True, is_rpc=True,
cancellation_token=message_envelope.cancellation_token, cancellation_token=message_envelope.cancellation_token,
# Will be fixed when send API removed message_id=message_envelope.message_id,
message_id="NOT_DEFINED_TODO_FIX",
) )
with MessageHandlerContext.populate_context(recipient_agent.id): with MessageHandlerContext.populate_context(recipient_agent.id):
response = await recipient_agent.on_message( response = await recipient_agent.on_message(
@ -331,12 +325,36 @@ class SingleThreadedAgentRuntime(AgentRuntime):
if not message_envelope.future.cancelled(): if not message_envelope.future.cancelled():
message_envelope.future.set_exception(e) message_envelope.future.set_exception(e)
self._message_queue.task_done() self._message_queue.task_done()
event_logger.info(
MessageHandlerExceptionEvent(
payload=self._try_serialize(message_envelope.message),
handling_agent=recipient,
exception=e,
)
)
return return
except BaseException as e: except BaseException as e:
message_envelope.future.set_exception(e) message_envelope.future.set_exception(e)
self._message_queue.task_done() self._message_queue.task_done()
event_logger.info(
MessageHandlerExceptionEvent(
payload=self._try_serialize(message_envelope.message),
handling_agent=recipient,
exception=e,
)
)
return return
event_logger.info(
MessageEvent(
payload=self._try_serialize(response),
sender=message_envelope.recipient,
receiver=message_envelope.sender,
kind=MessageKind.RESPOND,
delivery_stage=DeliveryStage.SEND,
)
)
await self._message_queue.put( await self._message_queue.put(
ResponseMessageEnvelope( ResponseMessageEnvelope(
message=response, message=response,
@ -365,15 +383,15 @@ class SingleThreadedAgentRuntime(AgentRuntime):
logger.info( logger.info(
f"Calling message handler for {agent_id.type} with message type {type(message_envelope.message).__name__} published by {sender_name}" f"Calling message handler for {agent_id.type} with message type {type(message_envelope.message).__name__} published by {sender_name}"
) )
# event_logger.info( event_logger.info(
# MessageEvent( MessageEvent(
# payload=message_envelope.message, payload=self._try_serialize(message_envelope.message),
# sender=message_envelope.sender, sender=message_envelope.sender,
# receiver=agent, receiver=None,
# kind=MessageKind.PUBLISH, kind=MessageKind.PUBLISH,
# delivery_stage=DeliveryStage.DELIVER, delivery_stage=DeliveryStage.DELIVER,
# ) )
# ) )
message_context = MessageContext( message_context = MessageContext(
sender=message_envelope.sender, sender=message_envelope.sender,
topic_id=message_envelope.topic_id, topic_id=message_envelope.topic_id,
@ -386,20 +404,29 @@ class SingleThreadedAgentRuntime(AgentRuntime):
async def _on_message(agent: Agent, message_context: MessageContext) -> Any: async def _on_message(agent: Agent, message_context: MessageContext) -> Any:
with self._tracer_helper.trace_block("process", agent.id, parent=None): with self._tracer_helper.trace_block("process", agent.id, parent=None):
with MessageHandlerContext.populate_context(agent.id): with MessageHandlerContext.populate_context(agent.id):
return await agent.on_message( try:
message_envelope.message, return await agent.on_message(
ctx=message_context, message_envelope.message,
) ctx=message_context,
)
except BaseException as e:
logger.error(f"Error processing publish message for {agent.id}", exc_info=True)
event_logger.info(
MessageHandlerExceptionEvent(
payload=self._try_serialize(message_envelope.message),
handling_agent=agent.id,
exception=e,
)
)
raise
future = _on_message(agent, message_context) future = _on_message(agent, message_context)
responses.append(future) responses.append(future)
await asyncio.gather(*responses) await asyncio.gather(*responses)
except BaseException as e: except BaseException:
# Ignore cancelled errors from logs # Ignore exceptions raised during publishing. We've already logged them above.
if isinstance(e, CancelledError): pass
return
logger.error("Error processing publish message", exc_info=True)
finally: finally:
self._message_queue.task_done() self._message_queue.task_done()
# TODO if responses are given for a publish # TODO if responses are given for a publish
@ -414,18 +441,18 @@ class SingleThreadedAgentRuntime(AgentRuntime):
logger.info( logger.info(
f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.type}: {content}" f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.type}: {content}"
) )
# event_logger.info( event_logger.info(
# MessageEvent( MessageEvent(
# payload=message_envelope.message, payload=self._try_serialize(message_envelope.message),
# sender=message_envelope.sender, sender=message_envelope.sender,
# receiver=message_envelope.recipient, receiver=message_envelope.recipient,
# kind=MessageKind.RESPOND, kind=MessageKind.RESPOND,
# delivery_stage=DeliveryStage.DELIVER, delivery_stage=DeliveryStage.DELIVER,
# ) )
# ) )
self._message_queue.task_done()
if not message_envelope.future.cancelled(): if not message_envelope.future.cancelled():
message_envelope.future.set_result(message_envelope.message) message_envelope.future.set_result(message_envelope.message)
self._message_queue.task_done()
@deprecated("Manually stepping the runtime processing is deprecated. Use start() instead.") @deprecated("Manually stepping the runtime processing is deprecated. Use start() instead.")
async def process_next(self) -> None: async def process_next(self) -> None:
@ -453,6 +480,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
future.set_exception(e) future.set_exception(e)
return return
if temp_message is DropMessage or isinstance(temp_message, DropMessage): if temp_message is DropMessage or isinstance(temp_message, DropMessage):
event_logger.info(
MessageDroppedEvent(
payload=self._try_serialize(message),
sender=sender,
receiver=recipient,
kind=MessageKind.DIRECT,
)
)
future.set_exception(MessageDroppedException()) future.set_exception(MessageDroppedException())
return return
@ -463,6 +498,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
case PublishMessageEnvelope( case PublishMessageEnvelope(
message=message, message=message,
sender=sender, sender=sender,
topic_id=topic_id,
): ):
if self._intervention_handlers is not None: if self._intervention_handlers is not None:
for handler in self._intervention_handlers: for handler in self._intervention_handlers:
@ -477,7 +513,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True) logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True)
return return
if temp_message is DropMessage or isinstance(temp_message, DropMessage): if temp_message is DropMessage or isinstance(temp_message, DropMessage):
# TODO log message dropped event_logger.info(
MessageDroppedEvent(
payload=self._try_serialize(message),
sender=sender,
receiver=topic_id,
kind=MessageKind.PUBLISH,
)
)
return return
message_envelope.message = temp_message message_envelope.message = temp_message
@ -495,6 +538,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
future.set_exception(e) future.set_exception(e)
return return
if temp_message is DropMessage or isinstance(temp_message, DropMessage): if temp_message is DropMessage or isinstance(temp_message, DropMessage):
event_logger.info(
MessageDroppedEvent(
payload=self._try_serialize(message),
sender=sender,
receiver=recipient,
kind=MessageKind.RESPOND,
)
)
future.set_exception(MessageDroppedException()) future.set_exception(MessageDroppedException())
return return
message_envelope.message = temp_message message_envelope.message = temp_message
@ -603,23 +654,34 @@ class SingleThreadedAgentRuntime(AgentRuntime):
agent_id: AgentId, agent_id: AgentId,
) -> T: ) -> T:
with AgentInstantiationContext.populate_context((self, agent_id)): with AgentInstantiationContext.populate_context((self, agent_id)):
if len(inspect.signature(agent_factory).parameters) == 0: try:
factory_one = cast(Callable[[], T], agent_factory) if len(inspect.signature(agent_factory).parameters) == 0:
agent = factory_one() factory_one = cast(Callable[[], T], agent_factory)
elif len(inspect.signature(agent_factory).parameters) == 2: agent = factory_one()
warnings.warn( elif len(inspect.signature(agent_factory).parameters) == 2:
"Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.", warnings.warn(
stacklevel=2, "Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.",
stacklevel=2,
)
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
agent = factory_two(self, agent_id)
else:
raise ValueError("Agent factory must take 0 or 2 arguments.")
if inspect.isawaitable(agent):
return cast(T, await agent)
return agent
except BaseException as e:
event_logger.info(
AgentConstructionExceptionEvent(
agent_id=agent_id,
exception=e,
)
) )
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory) logger.error(f"Error constructing agent {agent_id}", exc_info=True)
agent = factory_two(self, agent_id) raise
else:
raise ValueError("Agent factory must take 0 or 2 arguments.")
if inspect.isawaitable(agent):
return cast(T, await agent)
return agent
async def _get_agent(self, agent_id: AgentId) -> Agent: async def _get_agent(self, agent_id: AgentId) -> Agent:
if agent_id in self._instantiated_agents: if agent_id in self._instantiated_agents:
@ -666,3 +728,12 @@ class SingleThreadedAgentRuntime(AgentRuntime):
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None: def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
self._serialization_registry.add_serializer(serializer) self._serialization_registry.add_serializer(serializer)
def _try_serialize(self, message: Any) -> str:
try:
type_name = self._serialization_registry.type_name(message)
return self._serialization_registry.serialize(
message, type_name=type_name, data_content_type=JSON_DATA_CONTENT_TYPE
).decode("utf-8")
except ValueError:
return "Message could not be serialized"

View File

@ -2,7 +2,8 @@ import json
from enum import Enum from enum import Enum
from typing import Any, cast from typing import Any, cast
from autogen_core import AgentId from ._agent_id import AgentId
from ._topic import TopicId
class LLMCallEvent: class LLMCallEvent:
@ -57,9 +58,9 @@ class MessageEvent:
def __init__( def __init__(
self, self,
*, *,
payload: Any, payload: str,
sender: AgentId | None, sender: AgentId | None,
receiver: AgentId | None, receiver: AgentId | TopicId | None,
kind: MessageKind, kind: MessageKind,
delivery_stage: DeliveryStage, delivery_stage: DeliveryStage,
**kwargs: Any, **kwargs: Any,
@ -68,18 +69,70 @@ class MessageEvent:
self.kwargs["payload"] = payload self.kwargs["payload"] = payload
self.kwargs["sender"] = None if sender is None else str(sender) self.kwargs["sender"] = None if sender is None else str(sender)
self.kwargs["receiver"] = None if receiver is None else str(receiver) self.kwargs["receiver"] = None if receiver is None else str(receiver)
self.kwargs["kind"] = kind self.kwargs["kind"] = str(kind)
self.kwargs["delivery_stage"] = delivery_stage self.kwargs["delivery_stage"] = str(delivery_stage)
self.kwargs["type"] = "Message" self.kwargs["type"] = "Message"
@property # This must output the event in a json serializable format
def prompt_tokens(self) -> int: def __str__(self) -> str:
return cast(int, self.kwargs["prompt_tokens"]) return json.dumps(self.kwargs)
@property
def completion_tokens(self) -> int: class MessageDroppedEvent:
return cast(int, self.kwargs["completion_tokens"]) def __init__(
self,
*,
payload: str,
sender: AgentId | None,
receiver: AgentId | TopicId | None,
kind: MessageKind,
**kwargs: Any,
) -> None:
self.kwargs = kwargs
self.kwargs["payload"] = payload
self.kwargs["sender"] = None if sender is None else str(sender)
self.kwargs["receiver"] = None if receiver is None else str(receiver)
self.kwargs["kind"] = str(kind)
self.kwargs["type"] = "MessageDropped"
# This must output the event in a json serializable format
def __str__(self) -> str:
return json.dumps(self.kwargs)
class MessageHandlerExceptionEvent:
def __init__(
self,
*,
payload: str,
handling_agent: AgentId,
exception: BaseException,
**kwargs: Any,
) -> None:
self.kwargs = kwargs
self.kwargs["payload"] = payload
self.kwargs["handling_agent"] = str(handling_agent)
self.kwargs["exception"] = str(exception)
self.kwargs["type"] = "MessageHandlerException"
# This must output the event in a json serializable format
def __str__(self) -> str:
return json.dumps(self.kwargs)
class AgentConstructionExceptionEvent:
def __init__(
self,
*,
agent_id: AgentId,
exception: BaseException,
**kwargs: Any,
) -> None:
self.kwargs = kwargs
self.kwargs["agent_id"] = str(agent_id)
self.kwargs["exception"] = str(exception)
self.kwargs["type"] = "AgentConstructionException"
# This must output the event in a json serializable format # This must output the event in a json serializable format
def __str__(self) -> str: def __str__(self) -> str:
return json.dumps(self.kwargs) return json.dumps(self.kwargs)

View File

@ -78,7 +78,7 @@ async def test_message_handler_router() -> None:
@dataclass @dataclass
class TestMessage: class MyMessage:
value: str value: str
@ -89,15 +89,15 @@ class RoutedAgentMessageCustomMatch(RoutedAgent):
self.handler_two_called = False self.handler_two_called = False
@staticmethod @staticmethod
def match_one(message: TestMessage, ctx: MessageContext) -> bool: def match_one(message: MyMessage, ctx: MessageContext) -> bool:
return message.value == "one" return message.value == "one"
@message_handler(match=match_one) @message_handler(match=match_one)
async def handler_one(self, message: TestMessage, ctx: MessageContext) -> None: async def handler_one(self, message: MyMessage, ctx: MessageContext) -> None:
self.handler_one_called = True self.handler_one_called = True
@message_handler(match=cast(Callable[[TestMessage, MessageContext], bool], lambda msg, ctx: msg.value == "two")) # type: ignore @message_handler(match=cast(Callable[[MyMessage, MessageContext], bool], lambda msg, ctx: msg.value == "two")) # type: ignore
async def handler_two(self, message: TestMessage, ctx: MessageContext) -> None: async def handler_two(self, message: MyMessage, ctx: MessageContext) -> None:
self.handler_two_called = True self.handler_two_called = True
@ -113,14 +113,14 @@ async def test_routed_agent_message_matching() -> None:
assert agent.handler_two_called is False assert agent.handler_two_called is False
runtime.start() runtime.start()
await runtime.send_message(TestMessage("one"), recipient=agent_id) await runtime.send_message(MyMessage("one"), recipient=agent_id)
await runtime.stop_when_idle() await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch) agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch)
assert agent.handler_one_called is True assert agent.handler_one_called is True
assert agent.handler_two_called is False assert agent.handler_two_called is False
runtime.start() runtime.start()
await runtime.send_message(TestMessage("two"), recipient=agent_id) await runtime.send_message(MyMessage("two"), recipient=agent_id)
await runtime.stop_when_idle() await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch) agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch)
assert agent.handler_one_called is True assert agent.handler_one_called is True
@ -133,11 +133,11 @@ class EventAgent(RoutedAgent):
self.num_calls = [0, 0] self.num_calls = [0, 0]
@event(match=lambda msg, ctx: msg.value == "one") # type: ignore @event(match=lambda msg, ctx: msg.value == "one") # type: ignore
async def on_event_one(self, message: TestMessage, ctx: MessageContext) -> None: async def on_event_one(self, message: MyMessage, ctx: MessageContext) -> None:
self.num_calls[0] += 1 self.num_calls[0] += 1
@event(match=lambda msg, ctx: msg.value == "two") # type: ignore @event(match=lambda msg, ctx: msg.value == "two") # type: ignore
async def on_event_two(self, message: TestMessage, ctx: MessageContext) -> None: async def on_event_two(self, message: MyMessage, ctx: MessageContext) -> None:
self.num_calls[1] += 1 self.num_calls[1] += 1
@ -150,7 +150,7 @@ async def test_event() -> None:
# Send a broadcast message. # Send a broadcast message.
runtime.start() runtime.start()
await runtime.publish_message(TestMessage("one"), topic_id=TopicId("default", "default")) await runtime.publish_message(MyMessage("one"), topic_id=TopicId("default", "default"))
await runtime.stop_when_idle() await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent) agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent)
assert agent.num_calls[0] == 1 assert agent.num_calls[0] == 1
@ -158,7 +158,7 @@ async def test_event() -> None:
# Send another broadcast message. # Send another broadcast message.
runtime.start() runtime.start()
await runtime.publish_message(TestMessage("two"), topic_id=TopicId("default", "default")) await runtime.publish_message(MyMessage("two"), topic_id=TopicId("default", "default"))
await runtime.stop_when_idle() await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent) agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent)
assert agent.num_calls[0] == 1 assert agent.num_calls[0] == 1
@ -166,7 +166,7 @@ async def test_event() -> None:
# Send an RPC message, expect no change. # Send an RPC message, expect no change.
runtime.start() runtime.start()
await runtime.send_message(TestMessage("one"), recipient=agent_id) await runtime.send_message(MyMessage("one"), recipient=agent_id)
await runtime.stop_when_idle() await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent) agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent)
assert agent.num_calls[0] == 1 assert agent.num_calls[0] == 1
@ -179,12 +179,12 @@ class RPCAgent(RoutedAgent):
self.num_calls = [0, 0] self.num_calls = [0, 0]
@rpc(match=lambda msg, ctx: msg.value == "one") # type: ignore @rpc(match=lambda msg, ctx: msg.value == "one") # type: ignore
async def on_rpc_one(self, message: TestMessage, ctx: MessageContext) -> TestMessage: async def on_rpc_one(self, message: MyMessage, ctx: MessageContext) -> MyMessage:
self.num_calls[0] += 1 self.num_calls[0] += 1
return message return message
@rpc(match=lambda msg, ctx: msg.value == "two") # type: ignore @rpc(match=lambda msg, ctx: msg.value == "two") # type: ignore
async def on_rpc_two(self, message: TestMessage, ctx: MessageContext) -> TestMessage: async def on_rpc_two(self, message: MyMessage, ctx: MessageContext) -> MyMessage:
self.num_calls[1] += 1 self.num_calls[1] += 1
return message return message
@ -198,7 +198,7 @@ async def test_rpc() -> None:
# Send an RPC message. # Send an RPC message.
runtime.start() runtime.start()
await runtime.send_message(TestMessage("one"), recipient=agent_id) await runtime.send_message(MyMessage("one"), recipient=agent_id)
await runtime.stop_when_idle() await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent) agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent)
assert agent.num_calls[0] == 1 assert agent.num_calls[0] == 1
@ -206,7 +206,7 @@ async def test_rpc() -> None:
# Send another RPC message. # Send another RPC message.
runtime.start() runtime.start()
await runtime.send_message(TestMessage("two"), recipient=agent_id) await runtime.send_message(MyMessage("two"), recipient=agent_id)
await runtime.stop_when_idle() await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent) agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent)
assert agent.num_calls[0] == 1 assert agent.num_calls[0] == 1
@ -214,7 +214,7 @@ async def test_rpc() -> None:
# Send a broadcast message, expect no change. # Send a broadcast message, expect no change.
runtime.start() runtime.start()
await runtime.publish_message(TestMessage("one"), topic_id=TopicId("default", "default")) await runtime.publish_message(MyMessage("one"), topic_id=TopicId("default", "default"))
await runtime.stop_when_idle() await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent) agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent)
assert agent.num_calls[0] == 1 assert agent.num_calls[0] == 1

View File

@ -20,10 +20,10 @@ from autogen_test_utils import (
MessageType, MessageType,
NoopAgent, NoopAgent,
) )
from autogen_test_utils.telemetry_test_utils import TestExporter, get_test_tracer_provider from autogen_test_utils.telemetry_test_utils import MyTestExporter, get_test_tracer_provider
from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace import TracerProvider
test_exporter = TestExporter() test_exporter = MyTestExporter()
@pytest.fixture @pytest.fixture
@ -88,7 +88,7 @@ async def test_register_receives_publish(tracer_provider: TracerProvider) -> Non
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_receives_publish_with_exception(caplog: pytest.LogCaptureFixture) -> None: async def test_register_receives_publish_with_construction(caplog: pytest.LogCaptureFixture) -> None:
runtime = SingleThreadedAgentRuntime() runtime = SingleThreadedAgentRuntime()
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType)) runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
@ -103,8 +103,9 @@ async def test_register_receives_publish_with_exception(caplog: pytest.LogCaptur
runtime.start() runtime.start()
await runtime.publish_message(MessageType(), topic_id=TopicId("default", "default")) await runtime.publish_message(MessageType(), topic_id=TopicId("default", "default"))
await runtime.stop_when_idle() await runtime.stop_when_idle()
# Check if logger has the exception.
assert any("Error processing publish message" in e.message for e in caplog.records) # Check if logger has the exception.
assert any("Error constructing agent" in e.message for e in caplog.records)
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -339,7 +339,9 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
*, *,
sender: AgentId | None = None, sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None, cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> Any: ) -> Any:
# TODO: use message_id
if not self._running: if not self._running:
raise ValueError("Runtime must be running when sending message.") raise ValueError("Runtime must be running when sending message.")
if self._host_connection is None: if self._host_connection is None:

View File

@ -7,8 +7,10 @@ name = "autogen-test-utils"
version = "0.0.0" version = "0.0.0"
license = {file = "LICENSE-CODE"} license = {file = "LICENSE-CODE"}
requires-python = ">=3.10" requires-python = ">=3.10"
dependencies = ["autogen-core", dependencies = [
"autogen-core",
"pytest",
"opentelemetry-sdk>=1.27.0",
] ]
[tool.ruff] [tool.ruff]

View File

@ -1,10 +1,11 @@
from typing import List, Sequence from typing import List, Sequence
import pytest
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult
class TestExporter(SpanExporter): class MyTestExporter(SpanExporter):
def __init__(self) -> None: def __init__(self) -> None:
self.exported_spans: List[ReadableSpan] = [] self.exported_spans: List[ReadableSpan] = []
@ -24,7 +25,7 @@ class TestExporter(SpanExporter):
return self.exported_spans return self.exported_spans
def get_test_tracer_provider(exporter: TestExporter) -> TracerProvider: def get_test_tracer_provider(exporter: MyTestExporter) -> TracerProvider:
tracer_provider = TracerProvider() tracer_provider = TracerProvider()
tracer_provider.add_span_processor(SimpleSpanProcessor(exporter)) tracer_provider.add_span_processor(SimpleSpanProcessor(exporter))
return tracer_provider return tracer_provider

8
python/uv.lock generated
View File

@ -627,10 +627,16 @@ version = "0.0.0"
source = { editable = "packages/autogen-test-utils" } source = { editable = "packages/autogen-test-utils" }
dependencies = [ dependencies = [
{ name = "autogen-core" }, { name = "autogen-core" },
{ name = "opentelemetry-sdk" },
{ name = "pytest" },
] ]
[package.metadata] [package.metadata]
requires-dist = [{ name = "autogen-core", editable = "packages/autogen-core" }] requires-dist = [
{ name = "autogen-core", editable = "packages/autogen-core" },
{ name = "opentelemetry-sdk", specifier = ">=1.27.0" },
{ name = "pytest" },
]
[[package]] [[package]]
name = "autogenstudio" name = "autogenstudio"