mirror of
https://github.com/microsoft/autogen.git
synced 2025-10-12 00:20:50 +00:00
Update logged events, add message id to send message (#4868)
* Update logged events * add message_id * serialize payload for log * fix pytest warning * serialization * fix test * lock * fix warning and test
This commit is contained in:
parent
d2a74de3ad
commit
e6ac2f37fa
@ -33,12 +33,14 @@ class AgentProxy:
|
||||
*,
|
||||
sender: AgentId,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Any:
|
||||
return await self._runtime.send_message(
|
||||
message,
|
||||
recipient=self._agent,
|
||||
sender=sender,
|
||||
cancellation_token=cancellation_token,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
|
@ -26,6 +26,7 @@ class AgentRuntime(Protocol):
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Any:
|
||||
"""Send a message to an agent and get a response.
|
||||
|
||||
|
@ -121,6 +121,7 @@ class BaseAgent(ABC, Agent):
|
||||
recipient: AgentId,
|
||||
*,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Any:
|
||||
"""See :py:meth:`autogen_core.AgentRuntime.send_message` for more information."""
|
||||
if cancellation_token is None:
|
||||
@ -131,6 +132,7 @@ class BaseAgent(ABC, Agent):
|
||||
sender=self.id,
|
||||
recipient=recipient,
|
||||
cancellation_token=cancellation_token,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
async def publish_message(
|
||||
|
@ -61,6 +61,7 @@ class ClosureContext(Protocol):
|
||||
recipient: AgentId,
|
||||
*,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Any: ...
|
||||
|
||||
async def publish_message(
|
||||
|
@ -8,7 +8,7 @@ from ._agent_id import AgentId
|
||||
class MessageHandlerContext:
|
||||
def __init__(self) -> None:
|
||||
raise RuntimeError(
|
||||
"MessageHandlerContext cannot be instantiated. It is a static class that provides context management for agent instantiation."
|
||||
"MessageHandlerContext cannot be instantiated. It is a static class that provides context management for message handling."
|
||||
)
|
||||
|
||||
_MESSAGE_HANDLER_CONTEXT: ClassVar[ContextVar[AgentId]] = ContextVar("_MESSAGE_HANDLER_CONTEXT")
|
||||
|
@ -4,7 +4,6 @@ import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from asyncio import CancelledError, Future, Queue, Task
|
||||
@ -14,6 +13,15 @@ from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set
|
||||
|
||||
from opentelemetry.trace import TracerProvider
|
||||
|
||||
from .logging import (
|
||||
AgentConstructionExceptionEvent,
|
||||
DeliveryStage,
|
||||
MessageDroppedEvent,
|
||||
MessageEvent,
|
||||
MessageHandlerExceptionEvent,
|
||||
MessageKind,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from asyncio import Queue, QueueShutDown
|
||||
else:
|
||||
@ -32,7 +40,7 @@ from ._intervention import DropMessage, InterventionHandler
|
||||
from ._message_context import MessageContext
|
||||
from ._message_handler_context import MessageHandlerContext
|
||||
from ._runtime_impl_helpers import SubscriptionManager, get_impl
|
||||
from ._serialization import MessageSerializer, SerializationRegistry
|
||||
from ._serialization import JSON_DATA_CONTENT_TYPE, MessageSerializer, SerializationRegistry
|
||||
from ._subscription import Subscription
|
||||
from ._telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata
|
||||
from ._topic import TopicId
|
||||
@ -70,6 +78,7 @@ class SendMessageEnvelope:
|
||||
future: Future[Any]
|
||||
cancellation_token: CancellationToken
|
||||
metadata: EnvelopeMetadata | None = None
|
||||
message_id: str
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -87,25 +96,6 @@ P = ParamSpec("P")
|
||||
T = TypeVar("T", bound=Agent)
|
||||
|
||||
|
||||
class Counter:
|
||||
def __init__(self) -> None:
|
||||
self._count: int = 0
|
||||
self.threadLock = threading.Lock()
|
||||
|
||||
def increment(self) -> None:
|
||||
self.threadLock.acquire()
|
||||
self._count += 1
|
||||
self.threadLock.release()
|
||||
|
||||
def get(self) -> int:
|
||||
return self._count
|
||||
|
||||
def decrement(self) -> None:
|
||||
self.threadLock.acquire()
|
||||
self._count -= 1
|
||||
self.threadLock.release()
|
||||
|
||||
|
||||
class RunContext:
|
||||
def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
|
||||
self._runtime = runtime
|
||||
@ -194,19 +184,23 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Any:
|
||||
if cancellation_token is None:
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message,
|
||||
# sender=sender,
|
||||
# receiver=recipient,
|
||||
# kind=MessageKind.DIRECT,
|
||||
# delivery_stage=DeliveryStage.SEND,
|
||||
# )
|
||||
# )
|
||||
if message_id is None:
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
event_logger.info(
|
||||
MessageEvent(
|
||||
payload=self._try_serialize(message),
|
||||
sender=sender,
|
||||
receiver=recipient,
|
||||
kind=MessageKind.DIRECT,
|
||||
delivery_stage=DeliveryStage.SEND,
|
||||
)
|
||||
)
|
||||
|
||||
with self._tracer_helper.trace_block(
|
||||
"create",
|
||||
@ -229,6 +223,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
cancellation_token=cancellation_token,
|
||||
sender=sender,
|
||||
metadata=get_telemetry_envelope_metadata(),
|
||||
message_id=message_id,
|
||||
)
|
||||
)
|
||||
|
||||
@ -259,15 +254,15 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
if message_id is None:
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message,
|
||||
# sender=sender,
|
||||
# receiver=None,
|
||||
# kind=MessageKind.PUBLISH,
|
||||
# delivery_stage=DeliveryStage.SEND,
|
||||
# )
|
||||
# )
|
||||
event_logger.info(
|
||||
MessageEvent(
|
||||
payload=self._try_serialize(message),
|
||||
sender=sender,
|
||||
receiver=topic_id,
|
||||
kind=MessageKind.PUBLISH,
|
||||
delivery_stage=DeliveryStage.SEND,
|
||||
)
|
||||
)
|
||||
|
||||
await self._message_queue.put(
|
||||
PublishMessageEnvelope(
|
||||
@ -295,32 +290,31 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
|
||||
with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata):
|
||||
recipient = message_envelope.recipient
|
||||
# todo: check if recipient is in the known namespaces
|
||||
# assert recipient in self._agents
|
||||
|
||||
if recipient.type not in self._known_agent_names:
|
||||
raise LookupError(f"Agent type '{recipient.type}' does not exist.")
|
||||
|
||||
try:
|
||||
# TODO use id
|
||||
sender_name = message_envelope.sender.type if message_envelope.sender is not None else "Unknown"
|
||||
sender_id = str(message_envelope.sender) if message_envelope.sender is not None else "Unknown"
|
||||
logger.info(
|
||||
f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_name}"
|
||||
f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_id}"
|
||||
)
|
||||
event_logger.info(
|
||||
MessageEvent(
|
||||
payload=self._try_serialize(message_envelope.message),
|
||||
sender=message_envelope.sender,
|
||||
receiver=recipient,
|
||||
kind=MessageKind.DIRECT,
|
||||
delivery_stage=DeliveryStage.DELIVER,
|
||||
)
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message_envelope.message,
|
||||
# sender=message_envelope.sender,
|
||||
# receiver=recipient,
|
||||
# kind=MessageKind.DIRECT,
|
||||
# delivery_stage=DeliveryStage.DELIVER,
|
||||
# )
|
||||
# )
|
||||
recipient_agent = await self._get_agent(recipient)
|
||||
message_context = MessageContext(
|
||||
sender=message_envelope.sender,
|
||||
topic_id=None,
|
||||
is_rpc=True,
|
||||
cancellation_token=message_envelope.cancellation_token,
|
||||
# Will be fixed when send API removed
|
||||
message_id="NOT_DEFINED_TODO_FIX",
|
||||
message_id=message_envelope.message_id,
|
||||
)
|
||||
with MessageHandlerContext.populate_context(recipient_agent.id):
|
||||
response = await recipient_agent.on_message(
|
||||
@ -331,12 +325,36 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
if not message_envelope.future.cancelled():
|
||||
message_envelope.future.set_exception(e)
|
||||
self._message_queue.task_done()
|
||||
event_logger.info(
|
||||
MessageHandlerExceptionEvent(
|
||||
payload=self._try_serialize(message_envelope.message),
|
||||
handling_agent=recipient,
|
||||
exception=e,
|
||||
)
|
||||
)
|
||||
return
|
||||
except BaseException as e:
|
||||
message_envelope.future.set_exception(e)
|
||||
self._message_queue.task_done()
|
||||
event_logger.info(
|
||||
MessageHandlerExceptionEvent(
|
||||
payload=self._try_serialize(message_envelope.message),
|
||||
handling_agent=recipient,
|
||||
exception=e,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
event_logger.info(
|
||||
MessageEvent(
|
||||
payload=self._try_serialize(response),
|
||||
sender=message_envelope.recipient,
|
||||
receiver=message_envelope.sender,
|
||||
kind=MessageKind.RESPOND,
|
||||
delivery_stage=DeliveryStage.SEND,
|
||||
)
|
||||
)
|
||||
|
||||
await self._message_queue.put(
|
||||
ResponseMessageEnvelope(
|
||||
message=response,
|
||||
@ -365,15 +383,15 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
logger.info(
|
||||
f"Calling message handler for {agent_id.type} with message type {type(message_envelope.message).__name__} published by {sender_name}"
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message_envelope.message,
|
||||
# sender=message_envelope.sender,
|
||||
# receiver=agent,
|
||||
# kind=MessageKind.PUBLISH,
|
||||
# delivery_stage=DeliveryStage.DELIVER,
|
||||
# )
|
||||
# )
|
||||
event_logger.info(
|
||||
MessageEvent(
|
||||
payload=self._try_serialize(message_envelope.message),
|
||||
sender=message_envelope.sender,
|
||||
receiver=None,
|
||||
kind=MessageKind.PUBLISH,
|
||||
delivery_stage=DeliveryStage.DELIVER,
|
||||
)
|
||||
)
|
||||
message_context = MessageContext(
|
||||
sender=message_envelope.sender,
|
||||
topic_id=message_envelope.topic_id,
|
||||
@ -386,20 +404,29 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
async def _on_message(agent: Agent, message_context: MessageContext) -> Any:
|
||||
with self._tracer_helper.trace_block("process", agent.id, parent=None):
|
||||
with MessageHandlerContext.populate_context(agent.id):
|
||||
return await agent.on_message(
|
||||
message_envelope.message,
|
||||
ctx=message_context,
|
||||
)
|
||||
try:
|
||||
return await agent.on_message(
|
||||
message_envelope.message,
|
||||
ctx=message_context,
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(f"Error processing publish message for {agent.id}", exc_info=True)
|
||||
event_logger.info(
|
||||
MessageHandlerExceptionEvent(
|
||||
payload=self._try_serialize(message_envelope.message),
|
||||
handling_agent=agent.id,
|
||||
exception=e,
|
||||
)
|
||||
)
|
||||
raise
|
||||
|
||||
future = _on_message(agent, message_context)
|
||||
responses.append(future)
|
||||
|
||||
await asyncio.gather(*responses)
|
||||
except BaseException as e:
|
||||
# Ignore cancelled errors from logs
|
||||
if isinstance(e, CancelledError):
|
||||
return
|
||||
logger.error("Error processing publish message", exc_info=True)
|
||||
except BaseException:
|
||||
# Ignore exceptions raised during publishing. We've already logged them above.
|
||||
pass
|
||||
finally:
|
||||
self._message_queue.task_done()
|
||||
# TODO if responses are given for a publish
|
||||
@ -414,18 +441,18 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
logger.info(
|
||||
f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.type}: {content}"
|
||||
)
|
||||
# event_logger.info(
|
||||
# MessageEvent(
|
||||
# payload=message_envelope.message,
|
||||
# sender=message_envelope.sender,
|
||||
# receiver=message_envelope.recipient,
|
||||
# kind=MessageKind.RESPOND,
|
||||
# delivery_stage=DeliveryStage.DELIVER,
|
||||
# )
|
||||
# )
|
||||
self._message_queue.task_done()
|
||||
event_logger.info(
|
||||
MessageEvent(
|
||||
payload=self._try_serialize(message_envelope.message),
|
||||
sender=message_envelope.sender,
|
||||
receiver=message_envelope.recipient,
|
||||
kind=MessageKind.RESPOND,
|
||||
delivery_stage=DeliveryStage.DELIVER,
|
||||
)
|
||||
)
|
||||
if not message_envelope.future.cancelled():
|
||||
message_envelope.future.set_result(message_envelope.message)
|
||||
self._message_queue.task_done()
|
||||
|
||||
@deprecated("Manually stepping the runtime processing is deprecated. Use start() instead.")
|
||||
async def process_next(self) -> None:
|
||||
@ -453,6 +480,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
future.set_exception(e)
|
||||
return
|
||||
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
||||
event_logger.info(
|
||||
MessageDroppedEvent(
|
||||
payload=self._try_serialize(message),
|
||||
sender=sender,
|
||||
receiver=recipient,
|
||||
kind=MessageKind.DIRECT,
|
||||
)
|
||||
)
|
||||
future.set_exception(MessageDroppedException())
|
||||
return
|
||||
|
||||
@ -463,6 +498,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
case PublishMessageEnvelope(
|
||||
message=message,
|
||||
sender=sender,
|
||||
topic_id=topic_id,
|
||||
):
|
||||
if self._intervention_handlers is not None:
|
||||
for handler in self._intervention_handlers:
|
||||
@ -477,7 +513,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True)
|
||||
return
|
||||
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
||||
# TODO log message dropped
|
||||
event_logger.info(
|
||||
MessageDroppedEvent(
|
||||
payload=self._try_serialize(message),
|
||||
sender=sender,
|
||||
receiver=topic_id,
|
||||
kind=MessageKind.PUBLISH,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
message_envelope.message = temp_message
|
||||
@ -495,6 +538,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
future.set_exception(e)
|
||||
return
|
||||
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
|
||||
event_logger.info(
|
||||
MessageDroppedEvent(
|
||||
payload=self._try_serialize(message),
|
||||
sender=sender,
|
||||
receiver=recipient,
|
||||
kind=MessageKind.RESPOND,
|
||||
)
|
||||
)
|
||||
future.set_exception(MessageDroppedException())
|
||||
return
|
||||
message_envelope.message = temp_message
|
||||
@ -603,23 +654,34 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
agent_id: AgentId,
|
||||
) -> T:
|
||||
with AgentInstantiationContext.populate_context((self, agent_id)):
|
||||
if len(inspect.signature(agent_factory).parameters) == 0:
|
||||
factory_one = cast(Callable[[], T], agent_factory)
|
||||
agent = factory_one()
|
||||
elif len(inspect.signature(agent_factory).parameters) == 2:
|
||||
warnings.warn(
|
||||
"Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.",
|
||||
stacklevel=2,
|
||||
try:
|
||||
if len(inspect.signature(agent_factory).parameters) == 0:
|
||||
factory_one = cast(Callable[[], T], agent_factory)
|
||||
agent = factory_one()
|
||||
elif len(inspect.signature(agent_factory).parameters) == 2:
|
||||
warnings.warn(
|
||||
"Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version.",
|
||||
stacklevel=2,
|
||||
)
|
||||
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
|
||||
agent = factory_two(self, agent_id)
|
||||
else:
|
||||
raise ValueError("Agent factory must take 0 or 2 arguments.")
|
||||
|
||||
if inspect.isawaitable(agent):
|
||||
return cast(T, await agent)
|
||||
|
||||
return agent
|
||||
|
||||
except BaseException as e:
|
||||
event_logger.info(
|
||||
AgentConstructionExceptionEvent(
|
||||
agent_id=agent_id,
|
||||
exception=e,
|
||||
)
|
||||
)
|
||||
factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory)
|
||||
agent = factory_two(self, agent_id)
|
||||
else:
|
||||
raise ValueError("Agent factory must take 0 or 2 arguments.")
|
||||
|
||||
if inspect.isawaitable(agent):
|
||||
return cast(T, await agent)
|
||||
|
||||
return agent
|
||||
logger.error(f"Error constructing agent {agent_id}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def _get_agent(self, agent_id: AgentId) -> Agent:
|
||||
if agent_id in self._instantiated_agents:
|
||||
@ -666,3 +728,12 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
|
||||
self._serialization_registry.add_serializer(serializer)
|
||||
|
||||
def _try_serialize(self, message: Any) -> str:
|
||||
try:
|
||||
type_name = self._serialization_registry.type_name(message)
|
||||
return self._serialization_registry.serialize(
|
||||
message, type_name=type_name, data_content_type=JSON_DATA_CONTENT_TYPE
|
||||
).decode("utf-8")
|
||||
except ValueError:
|
||||
return "Message could not be serialized"
|
||||
|
@ -2,7 +2,8 @@ import json
|
||||
from enum import Enum
|
||||
from typing import Any, cast
|
||||
|
||||
from autogen_core import AgentId
|
||||
from ._agent_id import AgentId
|
||||
from ._topic import TopicId
|
||||
|
||||
|
||||
class LLMCallEvent:
|
||||
@ -57,9 +58,9 @@ class MessageEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
payload: Any,
|
||||
payload: str,
|
||||
sender: AgentId | None,
|
||||
receiver: AgentId | None,
|
||||
receiver: AgentId | TopicId | None,
|
||||
kind: MessageKind,
|
||||
delivery_stage: DeliveryStage,
|
||||
**kwargs: Any,
|
||||
@ -68,18 +69,70 @@ class MessageEvent:
|
||||
self.kwargs["payload"] = payload
|
||||
self.kwargs["sender"] = None if sender is None else str(sender)
|
||||
self.kwargs["receiver"] = None if receiver is None else str(receiver)
|
||||
self.kwargs["kind"] = kind
|
||||
self.kwargs["delivery_stage"] = delivery_stage
|
||||
self.kwargs["kind"] = str(kind)
|
||||
self.kwargs["delivery_stage"] = str(delivery_stage)
|
||||
self.kwargs["type"] = "Message"
|
||||
|
||||
@property
|
||||
def prompt_tokens(self) -> int:
|
||||
return cast(int, self.kwargs["prompt_tokens"])
|
||||
|
||||
@property
|
||||
def completion_tokens(self) -> int:
|
||||
return cast(int, self.kwargs["completion_tokens"])
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class MessageDroppedEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
payload: str,
|
||||
sender: AgentId | None,
|
||||
receiver: AgentId | TopicId | None,
|
||||
kind: MessageKind,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["payload"] = payload
|
||||
self.kwargs["sender"] = None if sender is None else str(sender)
|
||||
self.kwargs["receiver"] = None if receiver is None else str(receiver)
|
||||
self.kwargs["kind"] = str(kind)
|
||||
self.kwargs["type"] = "MessageDropped"
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class MessageHandlerExceptionEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
payload: str,
|
||||
handling_agent: AgentId,
|
||||
exception: BaseException,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["payload"] = payload
|
||||
self.kwargs["handling_agent"] = str(handling_agent)
|
||||
self.kwargs["exception"] = str(exception)
|
||||
self.kwargs["type"] = "MessageHandlerException"
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
||||
|
||||
class AgentConstructionExceptionEvent:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
agent_id: AgentId,
|
||||
exception: BaseException,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.kwargs["agent_id"] = str(agent_id)
|
||||
self.kwargs["exception"] = str(exception)
|
||||
self.kwargs["type"] = "AgentConstructionException"
|
||||
|
||||
# This must output the event in a json serializable format
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.kwargs)
|
||||
|
@ -78,7 +78,7 @@ async def test_message_handler_router() -> None:
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestMessage:
|
||||
class MyMessage:
|
||||
value: str
|
||||
|
||||
|
||||
@ -89,15 +89,15 @@ class RoutedAgentMessageCustomMatch(RoutedAgent):
|
||||
self.handler_two_called = False
|
||||
|
||||
@staticmethod
|
||||
def match_one(message: TestMessage, ctx: MessageContext) -> bool:
|
||||
def match_one(message: MyMessage, ctx: MessageContext) -> bool:
|
||||
return message.value == "one"
|
||||
|
||||
@message_handler(match=match_one)
|
||||
async def handler_one(self, message: TestMessage, ctx: MessageContext) -> None:
|
||||
async def handler_one(self, message: MyMessage, ctx: MessageContext) -> None:
|
||||
self.handler_one_called = True
|
||||
|
||||
@message_handler(match=cast(Callable[[TestMessage, MessageContext], bool], lambda msg, ctx: msg.value == "two")) # type: ignore
|
||||
async def handler_two(self, message: TestMessage, ctx: MessageContext) -> None:
|
||||
@message_handler(match=cast(Callable[[MyMessage, MessageContext], bool], lambda msg, ctx: msg.value == "two")) # type: ignore
|
||||
async def handler_two(self, message: MyMessage, ctx: MessageContext) -> None:
|
||||
self.handler_two_called = True
|
||||
|
||||
|
||||
@ -113,14 +113,14 @@ async def test_routed_agent_message_matching() -> None:
|
||||
assert agent.handler_two_called is False
|
||||
|
||||
runtime.start()
|
||||
await runtime.send_message(TestMessage("one"), recipient=agent_id)
|
||||
await runtime.send_message(MyMessage("one"), recipient=agent_id)
|
||||
await runtime.stop_when_idle()
|
||||
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch)
|
||||
assert agent.handler_one_called is True
|
||||
assert agent.handler_two_called is False
|
||||
|
||||
runtime.start()
|
||||
await runtime.send_message(TestMessage("two"), recipient=agent_id)
|
||||
await runtime.send_message(MyMessage("two"), recipient=agent_id)
|
||||
await runtime.stop_when_idle()
|
||||
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch)
|
||||
assert agent.handler_one_called is True
|
||||
@ -133,11 +133,11 @@ class EventAgent(RoutedAgent):
|
||||
self.num_calls = [0, 0]
|
||||
|
||||
@event(match=lambda msg, ctx: msg.value == "one") # type: ignore
|
||||
async def on_event_one(self, message: TestMessage, ctx: MessageContext) -> None:
|
||||
async def on_event_one(self, message: MyMessage, ctx: MessageContext) -> None:
|
||||
self.num_calls[0] += 1
|
||||
|
||||
@event(match=lambda msg, ctx: msg.value == "two") # type: ignore
|
||||
async def on_event_two(self, message: TestMessage, ctx: MessageContext) -> None:
|
||||
async def on_event_two(self, message: MyMessage, ctx: MessageContext) -> None:
|
||||
self.num_calls[1] += 1
|
||||
|
||||
|
||||
@ -150,7 +150,7 @@ async def test_event() -> None:
|
||||
|
||||
# Send a broadcast message.
|
||||
runtime.start()
|
||||
await runtime.publish_message(TestMessage("one"), topic_id=TopicId("default", "default"))
|
||||
await runtime.publish_message(MyMessage("one"), topic_id=TopicId("default", "default"))
|
||||
await runtime.stop_when_idle()
|
||||
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent)
|
||||
assert agent.num_calls[0] == 1
|
||||
@ -158,7 +158,7 @@ async def test_event() -> None:
|
||||
|
||||
# Send another broadcast message.
|
||||
runtime.start()
|
||||
await runtime.publish_message(TestMessage("two"), topic_id=TopicId("default", "default"))
|
||||
await runtime.publish_message(MyMessage("two"), topic_id=TopicId("default", "default"))
|
||||
await runtime.stop_when_idle()
|
||||
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent)
|
||||
assert agent.num_calls[0] == 1
|
||||
@ -166,7 +166,7 @@ async def test_event() -> None:
|
||||
|
||||
# Send an RPC message, expect no change.
|
||||
runtime.start()
|
||||
await runtime.send_message(TestMessage("one"), recipient=agent_id)
|
||||
await runtime.send_message(MyMessage("one"), recipient=agent_id)
|
||||
await runtime.stop_when_idle()
|
||||
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent)
|
||||
assert agent.num_calls[0] == 1
|
||||
@ -179,12 +179,12 @@ class RPCAgent(RoutedAgent):
|
||||
self.num_calls = [0, 0]
|
||||
|
||||
@rpc(match=lambda msg, ctx: msg.value == "one") # type: ignore
|
||||
async def on_rpc_one(self, message: TestMessage, ctx: MessageContext) -> TestMessage:
|
||||
async def on_rpc_one(self, message: MyMessage, ctx: MessageContext) -> MyMessage:
|
||||
self.num_calls[0] += 1
|
||||
return message
|
||||
|
||||
@rpc(match=lambda msg, ctx: msg.value == "two") # type: ignore
|
||||
async def on_rpc_two(self, message: TestMessage, ctx: MessageContext) -> TestMessage:
|
||||
async def on_rpc_two(self, message: MyMessage, ctx: MessageContext) -> MyMessage:
|
||||
self.num_calls[1] += 1
|
||||
return message
|
||||
|
||||
@ -198,7 +198,7 @@ async def test_rpc() -> None:
|
||||
|
||||
# Send an RPC message.
|
||||
runtime.start()
|
||||
await runtime.send_message(TestMessage("one"), recipient=agent_id)
|
||||
await runtime.send_message(MyMessage("one"), recipient=agent_id)
|
||||
await runtime.stop_when_idle()
|
||||
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent)
|
||||
assert agent.num_calls[0] == 1
|
||||
@ -206,7 +206,7 @@ async def test_rpc() -> None:
|
||||
|
||||
# Send another RPC message.
|
||||
runtime.start()
|
||||
await runtime.send_message(TestMessage("two"), recipient=agent_id)
|
||||
await runtime.send_message(MyMessage("two"), recipient=agent_id)
|
||||
await runtime.stop_when_idle()
|
||||
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent)
|
||||
assert agent.num_calls[0] == 1
|
||||
@ -214,7 +214,7 @@ async def test_rpc() -> None:
|
||||
|
||||
# Send a broadcast message, expect no change.
|
||||
runtime.start()
|
||||
await runtime.publish_message(TestMessage("one"), topic_id=TopicId("default", "default"))
|
||||
await runtime.publish_message(MyMessage("one"), topic_id=TopicId("default", "default"))
|
||||
await runtime.stop_when_idle()
|
||||
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent)
|
||||
assert agent.num_calls[0] == 1
|
||||
|
@ -20,10 +20,10 @@ from autogen_test_utils import (
|
||||
MessageType,
|
||||
NoopAgent,
|
||||
)
|
||||
from autogen_test_utils.telemetry_test_utils import TestExporter, get_test_tracer_provider
|
||||
from autogen_test_utils.telemetry_test_utils import MyTestExporter, get_test_tracer_provider
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
|
||||
test_exporter = TestExporter()
|
||||
test_exporter = MyTestExporter()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -88,7 +88,7 @@ async def test_register_receives_publish(tracer_provider: TracerProvider) -> Non
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_receives_publish_with_exception(caplog: pytest.LogCaptureFixture) -> None:
|
||||
async def test_register_receives_publish_with_construction(caplog: pytest.LogCaptureFixture) -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
|
||||
@ -103,8 +103,9 @@ async def test_register_receives_publish_with_exception(caplog: pytest.LogCaptur
|
||||
runtime.start()
|
||||
await runtime.publish_message(MessageType(), topic_id=TopicId("default", "default"))
|
||||
await runtime.stop_when_idle()
|
||||
# Check if logger has the exception.
|
||||
assert any("Error processing publish message" in e.message for e in caplog.records)
|
||||
|
||||
# Check if logger has the exception.
|
||||
assert any("Error constructing agent" in e.message for e in caplog.records)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -339,7 +339,9 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
|
||||
*,
|
||||
sender: AgentId | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Any:
|
||||
# TODO: use message_id
|
||||
if not self._running:
|
||||
raise ValueError("Runtime must be running when sending message.")
|
||||
if self._host_connection is None:
|
||||
|
@ -7,8 +7,10 @@ name = "autogen-test-utils"
|
||||
version = "0.0.0"
|
||||
license = {file = "LICENSE-CODE"}
|
||||
requires-python = ">=3.10"
|
||||
dependencies = ["autogen-core",
|
||||
|
||||
dependencies = [
|
||||
"autogen-core",
|
||||
"pytest",
|
||||
"opentelemetry-sdk>=1.27.0",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
|
@ -1,10 +1,11 @@
|
||||
from typing import List, Sequence
|
||||
|
||||
import pytest
|
||||
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult
|
||||
|
||||
|
||||
class TestExporter(SpanExporter):
|
||||
class MyTestExporter(SpanExporter):
|
||||
def __init__(self) -> None:
|
||||
self.exported_spans: List[ReadableSpan] = []
|
||||
|
||||
@ -24,7 +25,7 @@ class TestExporter(SpanExporter):
|
||||
return self.exported_spans
|
||||
|
||||
|
||||
def get_test_tracer_provider(exporter: TestExporter) -> TracerProvider:
|
||||
def get_test_tracer_provider(exporter: MyTestExporter) -> TracerProvider:
|
||||
tracer_provider = TracerProvider()
|
||||
tracer_provider.add_span_processor(SimpleSpanProcessor(exporter))
|
||||
return tracer_provider
|
||||
|
8
python/uv.lock
generated
8
python/uv.lock
generated
@ -627,10 +627,16 @@ version = "0.0.0"
|
||||
source = { editable = "packages/autogen-test-utils" }
|
||||
dependencies = [
|
||||
{ name = "autogen-core" },
|
||||
{ name = "opentelemetry-sdk" },
|
||||
{ name = "pytest" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [{ name = "autogen-core", editable = "packages/autogen-core" }]
|
||||
requires-dist = [
|
||||
{ name = "autogen-core", editable = "packages/autogen-core" },
|
||||
{ name = "opentelemetry-sdk", specifier = ">=1.27.0" },
|
||||
{ name = "pytest" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "autogenstudio"
|
||||
|
Loading…
x
Reference in New Issue
Block a user