Update logged events, add message id to send message (#4868)

* Update logged events

* add message_id

* serialize payload for log

* fix pytest warning

* serialization

* fix test

* lock

* fix warning and test
This commit is contained in:
Jack Gerrits 2024-12-31 15:11:48 -05:00 committed by GitHub
parent d2a74de3ad
commit e6ac2f37fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 283 additions and 141 deletions

View File

@ -33,12 +33,14 @@ class AgentProxy:
*,
sender: AgentId,
cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> Any:
return await self._runtime.send_message(
message,
recipient=self._agent,
sender=sender,
cancellation_token=cancellation_token,
message_id=message_id,
)
async def save_state(self) -> Mapping[str, Any]:

View File

@ -26,6 +26,7 @@ class AgentRuntime(Protocol):
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> Any:
"""Send a message to an agent and get a response.

View File

@ -121,6 +121,7 @@ class BaseAgent(ABC, Agent):
recipient: AgentId,
*,
cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> Any:
"""See :py:meth:`autogen_core.AgentRuntime.send_message` for more information."""
if cancellation_token is None:
@ -131,6 +132,7 @@ class BaseAgent(ABC, Agent):
sender=self.id,
recipient=recipient,
cancellation_token=cancellation_token,
message_id=message_id,
)
async def publish_message(

View File

@ -61,6 +61,7 @@ class ClosureContext(Protocol):
recipient: AgentId,
*,
cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> Any: ...
async def publish_message(

View File

@ -8,7 +8,7 @@ from ._agent_id import AgentId
class MessageHandlerContext:
def __init__(self) -> None:
raise RuntimeError(
"MessageHandlerContext cannot be instantiated. It is a static class that provides context management for agent instantiation."
"MessageHandlerContext cannot be instantiated. It is a static class that provides context management for message handling."
)
_MESSAGE_HANDLER_CONTEXT: ClassVar[ContextVar[AgentId]] = ContextVar("_MESSAGE_HANDLER_CONTEXT")

View File

@ -4,7 +4,6 @@ import asyncio
import inspect
import logging
import sys
import threading
import uuid
import warnings
from asyncio import CancelledError, Future, Queue, Task
@ -14,6 +13,15 @@ from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set
from opentelemetry.trace import TracerProvider
from .logging import (
AgentConstructionExceptionEvent,
DeliveryStage,
MessageDroppedEvent,
MessageEvent,
MessageHandlerExceptionEvent,
MessageKind,
)
if sys.version_info >= (3, 13):
from asyncio import Queue, QueueShutDown
else:
@ -32,7 +40,7 @@ from ._intervention import DropMessage, InterventionHandler
from ._message_context import MessageContext
from ._message_handler_context import MessageHandlerContext
from ._runtime_impl_helpers import SubscriptionManager, get_impl
from ._serialization import MessageSerializer, SerializationRegistry
from ._serialization import JSON_DATA_CONTENT_TYPE, MessageSerializer, SerializationRegistry
from ._subscription import Subscription
from ._telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata
from ._topic import TopicId
@ -70,6 +78,7 @@ class SendMessageEnvelope:
future: Future[Any]
cancellation_token: CancellationToken
metadata: EnvelopeMetadata | None = None
message_id: str
@dataclass(kw_only=True)
@ -87,25 +96,6 @@ P = ParamSpec("P")
T = TypeVar("T", bound=Agent)
class Counter:
def __init__(self) -> None:
self._count: int = 0
self.threadLock = threading.Lock()
def increment(self) -> None:
self.threadLock.acquire()
self._count += 1
self.threadLock.release()
def get(self) -> int:
return self._count
def decrement(self) -> None:
self.threadLock.acquire()
self._count -= 1
self.threadLock.release()
class RunContext:
def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
self._runtime = runtime
@ -194,19 +184,23 @@ class SingleThreadedAgentRuntime(AgentRuntime):
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> Any:
if cancellation_token is None:
cancellation_token = CancellationToken()
# event_logger.info(
# MessageEvent(
# payload=message,
# sender=sender,
# receiver=recipient,
# kind=MessageKind.DIRECT,
# delivery_stage=DeliveryStage.SEND,
# )
# )
if message_id is None:
message_id = str(uuid.uuid4())
event_logger.info(
MessageEvent(
payload=self._try_serialize(message),
sender=sender,
receiver=recipient,
kind=MessageKind.DIRECT,
delivery_stage=DeliveryStage.SEND,
)
)
with self._tracer_helper.trace_block(
"create",
@ -229,6 +223,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
cancellation_token=cancellation_token,
sender=sender,
metadata=get_telemetry_envelope_metadata(),
message_id=message_id,
)
)
@ -259,15 +254,15 @@ class SingleThreadedAgentRuntime(AgentRuntime):
if message_id is None:
message_id = str(uuid.uuid4())
# event_logger.info(
# MessageEvent(
# payload=message,
# sender=sender,
# receiver=None,
# kind=MessageKind.PUBLISH,
# delivery_stage=DeliveryStage.SEND,
# )
# )
event_logger.info(
MessageEvent(
payload=self._try_serialize(message),
sender=sender,
receiver=topic_id,
kind=MessageKind.PUBLISH,
delivery_stage=DeliveryStage.SEND,
)
)
await self._message_queue.put(
PublishMessageEnvelope(
@ -295,32 +290,31 @@ class SingleThreadedAgentRuntime(AgentRuntime):
async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata):
recipient = message_envelope.recipient
# todo: check if recipient is in the known namespaces
# assert recipient in self._agents
if recipient.type not in self._known_agent_names:
raise LookupError(f"Agent type '{recipient.type}' does not exist.")
try:
# TODO use id
sender_name = message_envelope.sender.type if message_envelope.sender is not None else "Unknown"
sender_id = str(message_envelope.sender) if message_envelope.sender is not None else "Unknown"
logger.info(
f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_name}"
f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_id}"
)
event_logger.info(
MessageEvent(
payload=self._try_serialize(message_envelope.message),
sender=message_envelope.sender,
receiver=recipient,
kind=MessageKind.DIRECT,
delivery_stage=DeliveryStage.DELIVER,
)
)
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=recipient,
# kind=MessageKind.DIRECT,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
recipient_agent = await self._get_agent(recipient)
message_context = MessageContext(
sender=message_envelope.sender,
topic_id=None,
is_rpc=True,
cancellation_token=message_envelope.cancellation_token,
# Will be fixed when send API removed
message_id="NOT_DEFINED_TODO_FIX",
message_id=message_envelope.message_id,
)
with MessageHandlerContext.populate_context(recipient_agent.id):
response = await recipient_agent.on_message(
@ -331,12 +325,36 @@ class SingleThreadedAgentRuntime(AgentRuntime):
if not message_envelope.future.cancelled():
message_envelope.future.set_exception(e)
self._message_queue.task_done()
event_logger.info(
MessageHandlerExceptionEvent(
payload=self._try_serialize(message_envelope.message),
handling_agent=recipient,
exception=e,
)
)
return
except BaseException as e:
message_envelope.future.set_exception(e)
self._message_queue.task_done()
event_logger.info(
MessageHandlerExceptionEvent(
payload=self._try_serialize(message_envelope.message),
handling_agent=recipient,
exception=e,
)
)
return
event_logger.info(
MessageEvent(
payload=self._try_serialize(response),
sender=message_envelope.recipient,
receiver=message_envelope.sender,
kind=MessageKind.RESPOND,
delivery_stage=DeliveryStage.SEND,
)
)
await self._message_queue.put(
ResponseMessageEnvelope(
message=response,
@ -365,15 +383,15 @@ class SingleThreadedAgentRuntime(AgentRuntime):
logger.info(
f"Calling message handler for {agent_id.type} with message type {type(message_envelope.message).__name__} published by {sender_name}"
)
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=agent,
# kind=MessageKind.PUBLISH,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
event_logger.info(
MessageEvent(
payload=self._try_serialize(message_envelope.message),
sender=message_envelope.sender,
receiver=None,
kind=MessageKind.PUBLISH,
delivery_stage=DeliveryStage.DELIVER,
)
)
message_context = MessageContext(
sender=message_envelope.sender,
topic_id=message_envelope.topic_id,
@ -386,20 +404,29 @@ class SingleThreadedAgentRuntime(AgentRuntime):
async def _on_message(agent: Agent, message_context: MessageContext) -> Any:
with self._tracer_helper.trace_block("process", agent.id, parent=None):
with MessageHandlerContext.populate_context(agent.id):
try:
return await agent.on_message(
message_envelope.message,
ctx=message_context,
)
except BaseException as e:
logger.error(f"Error processing publish message for {agent.id}", exc_info=True)
event_logger.info(
MessageHandlerExceptionEvent(
payload=self._try_serialize(message_envelope.message),
handling_agent=agent.id,
exception=e,
)
)
raise
future = _on_message(agent, message_context)
responses.append(future)
await asyncio.gather(*responses)
except BaseException as e:
# Ignore cancelled errors from logs
if isinstance(e, CancelledError):
return
logger.error("Error processing publish message", exc_info=True)
except BaseException:
# Ignore exceptions raised during publishing. We've already logged them above.
pass
finally:
self._message_queue.task_done()
# TODO if responses are given for a publish
@ -414,18 +441,18 @@ class SingleThreadedAgentRuntime(AgentRuntime):
logger.info(
f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.type}: {content}"
)
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=message_envelope.recipient,
# kind=MessageKind.RESPOND,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
self._message_queue.task_done()
event_logger.info(
MessageEvent(
payload=self._try_serialize(message_envelope.message),
sender=message_envelope.sender,
receiver=message_envelope.recipient,
kind=MessageKind.RESPOND,
delivery_stage=DeliveryStage.DELIVER,
)
)
if not message_envelope.future.cancelled():
message_envelope.future.set_result(message_envelope.message)
self._message_queue.task_done()
@deprecated("Manually stepping the runtime processing is deprecated. Use start() instead.")
async def process_next(self) -> None:
@ -453,6 +480,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
future.set_exception(e)
return
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
event_logger.info(
MessageDroppedEvent(
payload=self._try_serialize(message),
sender=sender,
receiver=recipient,
kind=MessageKind.DIRECT,
)
)
future.set_exception(MessageDroppedException())
return
@ -463,6 +498,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
case PublishMessageEnvelope(
message=message,
sender=sender,
topic_id=topic_id,
):
if self._intervention_handlers is not None:
for handler in self._intervention_handlers:
@ -477,7 +513,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True)
return
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
# TODO log message dropped
event_logger.info(
MessageDroppedEvent(
payload=self._try_serialize(message),
sender=sender,
receiver=topic_id,
kind=MessageKind.PUBLISH,
)
)
return
message_envelope.message = temp_message
@ -495,6 +538,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
future.set_exception(e)
return
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
event_logger.info(
MessageDroppedEvent(
payload=self._try_serialize(message),
sender=sender,
receiver=recipient,
kind=MessageKind.RESPOND,
)
)
future.set_exception(MessageDroppedException())
return
message_envelope.message = temp_message
@ -603,6 +654,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
agent_id: AgentId,
) -> T:
with AgentInstantiationContext.populate_context((self, agent_id)):
try:
if len(inspect.signature(agent_factory).parameters) == 0:
factory_one = cast(Callable[[], T], agent_factory)
agent = factory_one()
@ -621,6 +673,16 @@ class SingleThreadedAgentRuntime(AgentRuntime):
return agent
except BaseException as e:
event_logger.info(
AgentConstructionExceptionEvent(
agent_id=agent_id,
exception=e,
)
)
logger.error(f"Error constructing agent {agent_id}", exc_info=True)
raise
async def _get_agent(self, agent_id: AgentId) -> Agent:
if agent_id in self._instantiated_agents:
return self._instantiated_agents[agent_id]
@ -666,3 +728,12 @@ class SingleThreadedAgentRuntime(AgentRuntime):
def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None:
self._serialization_registry.add_serializer(serializer)
def _try_serialize(self, message: Any) -> str:
try:
type_name = self._serialization_registry.type_name(message)
return self._serialization_registry.serialize(
message, type_name=type_name, data_content_type=JSON_DATA_CONTENT_TYPE
).decode("utf-8")
except ValueError:
return "Message could not be serialized"

View File

@ -2,7 +2,8 @@ import json
from enum import Enum
from typing import Any, cast
from autogen_core import AgentId
from ._agent_id import AgentId
from ._topic import TopicId
class LLMCallEvent:
@ -57,9 +58,9 @@ class MessageEvent:
def __init__(
self,
*,
payload: Any,
payload: str,
sender: AgentId | None,
receiver: AgentId | None,
receiver: AgentId | TopicId | None,
kind: MessageKind,
delivery_stage: DeliveryStage,
**kwargs: Any,
@ -68,18 +69,70 @@ class MessageEvent:
self.kwargs["payload"] = payload
self.kwargs["sender"] = None if sender is None else str(sender)
self.kwargs["receiver"] = None if receiver is None else str(receiver)
self.kwargs["kind"] = kind
self.kwargs["delivery_stage"] = delivery_stage
self.kwargs["kind"] = str(kind)
self.kwargs["delivery_stage"] = str(delivery_stage)
self.kwargs["type"] = "Message"
@property
def prompt_tokens(self) -> int:
return cast(int, self.kwargs["prompt_tokens"])
@property
def completion_tokens(self) -> int:
return cast(int, self.kwargs["completion_tokens"])
# This must output the event in a json serializable format
def __str__(self) -> str:
return json.dumps(self.kwargs)
class MessageDroppedEvent:
def __init__(
self,
*,
payload: str,
sender: AgentId | None,
receiver: AgentId | TopicId | None,
kind: MessageKind,
**kwargs: Any,
) -> None:
self.kwargs = kwargs
self.kwargs["payload"] = payload
self.kwargs["sender"] = None if sender is None else str(sender)
self.kwargs["receiver"] = None if receiver is None else str(receiver)
self.kwargs["kind"] = str(kind)
self.kwargs["type"] = "MessageDropped"
# This must output the event in a json serializable format
def __str__(self) -> str:
return json.dumps(self.kwargs)
class MessageHandlerExceptionEvent:
def __init__(
self,
*,
payload: str,
handling_agent: AgentId,
exception: BaseException,
**kwargs: Any,
) -> None:
self.kwargs = kwargs
self.kwargs["payload"] = payload
self.kwargs["handling_agent"] = str(handling_agent)
self.kwargs["exception"] = str(exception)
self.kwargs["type"] = "MessageHandlerException"
# This must output the event in a json serializable format
def __str__(self) -> str:
return json.dumps(self.kwargs)
class AgentConstructionExceptionEvent:
def __init__(
self,
*,
agent_id: AgentId,
exception: BaseException,
**kwargs: Any,
) -> None:
self.kwargs = kwargs
self.kwargs["agent_id"] = str(agent_id)
self.kwargs["exception"] = str(exception)
self.kwargs["type"] = "AgentConstructionException"
# This must output the event in a json serializable format
def __str__(self) -> str:
return json.dumps(self.kwargs)

View File

@ -78,7 +78,7 @@ async def test_message_handler_router() -> None:
@dataclass
class TestMessage:
class MyMessage:
value: str
@ -89,15 +89,15 @@ class RoutedAgentMessageCustomMatch(RoutedAgent):
self.handler_two_called = False
@staticmethod
def match_one(message: TestMessage, ctx: MessageContext) -> bool:
def match_one(message: MyMessage, ctx: MessageContext) -> bool:
return message.value == "one"
@message_handler(match=match_one)
async def handler_one(self, message: TestMessage, ctx: MessageContext) -> None:
async def handler_one(self, message: MyMessage, ctx: MessageContext) -> None:
self.handler_one_called = True
@message_handler(match=cast(Callable[[TestMessage, MessageContext], bool], lambda msg, ctx: msg.value == "two")) # type: ignore
async def handler_two(self, message: TestMessage, ctx: MessageContext) -> None:
@message_handler(match=cast(Callable[[MyMessage, MessageContext], bool], lambda msg, ctx: msg.value == "two")) # type: ignore
async def handler_two(self, message: MyMessage, ctx: MessageContext) -> None:
self.handler_two_called = True
@ -113,14 +113,14 @@ async def test_routed_agent_message_matching() -> None:
assert agent.handler_two_called is False
runtime.start()
await runtime.send_message(TestMessage("one"), recipient=agent_id)
await runtime.send_message(MyMessage("one"), recipient=agent_id)
await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch)
assert agent.handler_one_called is True
assert agent.handler_two_called is False
runtime.start()
await runtime.send_message(TestMessage("two"), recipient=agent_id)
await runtime.send_message(MyMessage("two"), recipient=agent_id)
await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch)
assert agent.handler_one_called is True
@ -133,11 +133,11 @@ class EventAgent(RoutedAgent):
self.num_calls = [0, 0]
@event(match=lambda msg, ctx: msg.value == "one") # type: ignore
async def on_event_one(self, message: TestMessage, ctx: MessageContext) -> None:
async def on_event_one(self, message: MyMessage, ctx: MessageContext) -> None:
self.num_calls[0] += 1
@event(match=lambda msg, ctx: msg.value == "two") # type: ignore
async def on_event_two(self, message: TestMessage, ctx: MessageContext) -> None:
async def on_event_two(self, message: MyMessage, ctx: MessageContext) -> None:
self.num_calls[1] += 1
@ -150,7 +150,7 @@ async def test_event() -> None:
# Send a broadcast message.
runtime.start()
await runtime.publish_message(TestMessage("one"), topic_id=TopicId("default", "default"))
await runtime.publish_message(MyMessage("one"), topic_id=TopicId("default", "default"))
await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent)
assert agent.num_calls[0] == 1
@ -158,7 +158,7 @@ async def test_event() -> None:
# Send another broadcast message.
runtime.start()
await runtime.publish_message(TestMessage("two"), topic_id=TopicId("default", "default"))
await runtime.publish_message(MyMessage("two"), topic_id=TopicId("default", "default"))
await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent)
assert agent.num_calls[0] == 1
@ -166,7 +166,7 @@ async def test_event() -> None:
# Send an RPC message, expect no change.
runtime.start()
await runtime.send_message(TestMessage("one"), recipient=agent_id)
await runtime.send_message(MyMessage("one"), recipient=agent_id)
await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent)
assert agent.num_calls[0] == 1
@ -179,12 +179,12 @@ class RPCAgent(RoutedAgent):
self.num_calls = [0, 0]
@rpc(match=lambda msg, ctx: msg.value == "one") # type: ignore
async def on_rpc_one(self, message: TestMessage, ctx: MessageContext) -> TestMessage:
async def on_rpc_one(self, message: MyMessage, ctx: MessageContext) -> MyMessage:
self.num_calls[0] += 1
return message
@rpc(match=lambda msg, ctx: msg.value == "two") # type: ignore
async def on_rpc_two(self, message: TestMessage, ctx: MessageContext) -> TestMessage:
async def on_rpc_two(self, message: MyMessage, ctx: MessageContext) -> MyMessage:
self.num_calls[1] += 1
return message
@ -198,7 +198,7 @@ async def test_rpc() -> None:
# Send an RPC message.
runtime.start()
await runtime.send_message(TestMessage("one"), recipient=agent_id)
await runtime.send_message(MyMessage("one"), recipient=agent_id)
await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent)
assert agent.num_calls[0] == 1
@ -206,7 +206,7 @@ async def test_rpc() -> None:
# Send another RPC message.
runtime.start()
await runtime.send_message(TestMessage("two"), recipient=agent_id)
await runtime.send_message(MyMessage("two"), recipient=agent_id)
await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent)
assert agent.num_calls[0] == 1
@ -214,7 +214,7 @@ async def test_rpc() -> None:
# Send a broadcast message, expect no change.
runtime.start()
await runtime.publish_message(TestMessage("one"), topic_id=TopicId("default", "default"))
await runtime.publish_message(MyMessage("one"), topic_id=TopicId("default", "default"))
await runtime.stop_when_idle()
agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent)
assert agent.num_calls[0] == 1

View File

@ -20,10 +20,10 @@ from autogen_test_utils import (
MessageType,
NoopAgent,
)
from autogen_test_utils.telemetry_test_utils import TestExporter, get_test_tracer_provider
from autogen_test_utils.telemetry_test_utils import MyTestExporter, get_test_tracer_provider
from opentelemetry.sdk.trace import TracerProvider
test_exporter = TestExporter()
test_exporter = MyTestExporter()
@pytest.fixture
@ -88,7 +88,7 @@ async def test_register_receives_publish(tracer_provider: TracerProvider) -> Non
@pytest.mark.asyncio
async def test_register_receives_publish_with_exception(caplog: pytest.LogCaptureFixture) -> None:
async def test_register_receives_publish_with_construction(caplog: pytest.LogCaptureFixture) -> None:
runtime = SingleThreadedAgentRuntime()
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
@ -103,8 +103,9 @@ async def test_register_receives_publish_with_exception(caplog: pytest.LogCaptur
runtime.start()
await runtime.publish_message(MessageType(), topic_id=TopicId("default", "default"))
await runtime.stop_when_idle()
# Check if logger has the exception.
assert any("Error processing publish message" in e.message for e in caplog.records)
assert any("Error constructing agent" in e.message for e in caplog.records)
@pytest.mark.asyncio

View File

@ -339,7 +339,9 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> Any:
# TODO: use message_id
if not self._running:
raise ValueError("Runtime must be running when sending message.")
if self._host_connection is None:

View File

@ -7,8 +7,10 @@ name = "autogen-test-utils"
version = "0.0.0"
license = {file = "LICENSE-CODE"}
requires-python = ">=3.10"
dependencies = ["autogen-core",
dependencies = [
"autogen-core",
"pytest",
"opentelemetry-sdk>=1.27.0",
]
[tool.ruff]

View File

@ -1,10 +1,11 @@
from typing import List, Sequence
import pytest
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult
class TestExporter(SpanExporter):
class MyTestExporter(SpanExporter):
def __init__(self) -> None:
self.exported_spans: List[ReadableSpan] = []
@ -24,7 +25,7 @@ class TestExporter(SpanExporter):
return self.exported_spans
def get_test_tracer_provider(exporter: TestExporter) -> TracerProvider:
def get_test_tracer_provider(exporter: MyTestExporter) -> TracerProvider:
tracer_provider = TracerProvider()
tracer_provider.add_span_processor(SimpleSpanProcessor(exporter))
return tracer_provider

8
python/uv.lock generated
View File

@ -627,10 +627,16 @@ version = "0.0.0"
source = { editable = "packages/autogen-test-utils" }
dependencies = [
{ name = "autogen-core" },
{ name = "opentelemetry-sdk" },
{ name = "pytest" },
]
[package.metadata]
requires-dist = [{ name = "autogen-core", editable = "packages/autogen-core" }]
requires-dist = [
{ name = "autogen-core", editable = "packages/autogen-core" },
{ name = "opentelemetry-sdk", specifier = ">=1.27.0" },
{ name = "pytest" },
]
[[package]]
name = "autogenstudio"