mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-25 22:18:53 +00:00
fix!: Fix SingleThreadedAgentRuntime busy loop (#4855)
* Fix high cpu usage * Use queue for shutdown * mypy fixes * formatting * missing import
This commit is contained in:
parent
49b52db6ea
commit
190fcd15ed
264
python/packages/autogen-core/src/autogen_core/_queue.py
Normal file
264
python/packages/autogen-core/src/autogen_core/_queue.py
Normal file
@ -0,0 +1,264 @@
|
||||
# Copy of Asyncio queue: https://github.com/python/cpython/blob/main/Lib/asyncio/queues.py
|
||||
# So that shutdown can be used in <3.13
|
||||
# Modified to work outside of the asyncio package
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import threading
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
_global_lock = threading.Lock()
|
||||
|
||||
|
||||
class _LoopBoundMixin:
|
||||
_loop = None
|
||||
|
||||
def _get_loop(self) -> asyncio.AbstractEventLoop:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
if self._loop is None:
|
||||
with _global_lock:
|
||||
if self._loop is None:
|
||||
self._loop = loop
|
||||
if loop is not self._loop:
|
||||
raise RuntimeError(f"{self!r} is bound to a different event loop")
|
||||
return loop
|
||||
|
||||
|
||||
class QueueShutDown(Exception):
|
||||
"""Raised when putting on to or getting from a shut-down Queue."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Queue(_LoopBoundMixin, Generic[T]):
|
||||
def __init__(self, maxsize: int = 0):
|
||||
self._maxsize = maxsize
|
||||
self._getters = collections.deque[asyncio.Future[None]]()
|
||||
self._putters = collections.deque[asyncio.Future[None]]()
|
||||
self._unfinished_tasks = 0
|
||||
self._finished = asyncio.Event()
|
||||
self._finished.set()
|
||||
self._queue = collections.deque[T]()
|
||||
self._is_shutdown = False
|
||||
|
||||
# These three are overridable in subclasses.
|
||||
|
||||
def _get(self) -> T:
|
||||
return self._queue.popleft()
|
||||
|
||||
def _put(self, item: T) -> None:
|
||||
self._queue.append(item)
|
||||
|
||||
# End of the overridable methods.
|
||||
|
||||
def _wakeup_next(self, waiters: collections.deque[asyncio.Future[None]]) -> None:
|
||||
# Wake up the next waiter (if any) that isn't cancelled.
|
||||
while waiters:
|
||||
waiter = waiters.popleft()
|
||||
if not waiter.done():
|
||||
waiter.set_result(None)
|
||||
break
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{type(self).__name__} at {id(self):#x} {self._format()}>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"<{type(self).__name__} {self._format()}>"
|
||||
|
||||
def _format(self) -> str:
|
||||
result = f"maxsize={self._maxsize!r}"
|
||||
if getattr(self, "_queue", None):
|
||||
result += f" _queue={list(self._queue)!r}"
|
||||
if self._getters:
|
||||
result += f" _getters[{len(self._getters)}]"
|
||||
if self._putters:
|
||||
result += f" _putters[{len(self._putters)}]"
|
||||
if self._unfinished_tasks:
|
||||
result += f" tasks={self._unfinished_tasks}"
|
||||
if self._is_shutdown:
|
||||
result += " shutdown"
|
||||
return result
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""Number of items in the queue."""
|
||||
return len(self._queue)
|
||||
|
||||
@property
|
||||
def maxsize(self) -> int:
|
||||
"""Number of items allowed in the queue."""
|
||||
return self._maxsize
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Return True if the queue is empty, False otherwise."""
|
||||
return not self._queue
|
||||
|
||||
def full(self) -> bool:
|
||||
"""Return True if there are maxsize items in the queue.
|
||||
|
||||
Note: if the Queue was initialized with maxsize=0 (the default),
|
||||
then full() is never True.
|
||||
"""
|
||||
if self._maxsize <= 0:
|
||||
return False
|
||||
else:
|
||||
return self.qsize() >= self._maxsize
|
||||
|
||||
async def put(self, item: T) -> None:
|
||||
"""Put an item into the queue.
|
||||
|
||||
Put an item into the queue. If the queue is full, wait until a free
|
||||
slot is available before adding item.
|
||||
|
||||
Raises QueueShutDown if the queue has been shut down.
|
||||
"""
|
||||
while self.full():
|
||||
if self._is_shutdown:
|
||||
raise QueueShutDown
|
||||
putter = self._get_loop().create_future()
|
||||
self._putters.append(putter)
|
||||
try:
|
||||
await putter
|
||||
except:
|
||||
putter.cancel() # Just in case putter is not done yet.
|
||||
try:
|
||||
# Clean self._putters from canceled putters.
|
||||
self._putters.remove(putter)
|
||||
except ValueError:
|
||||
# The putter could be removed from self._putters by a
|
||||
# previous get_nowait call or a shutdown call.
|
||||
pass
|
||||
if not self.full() and not putter.cancelled():
|
||||
# We were woken up by get_nowait(), but can't take
|
||||
# the call. Wake up the next in line.
|
||||
self._wakeup_next(self._putters)
|
||||
raise
|
||||
return self.put_nowait(item)
|
||||
|
||||
def put_nowait(self, item: T) -> None:
|
||||
"""Put an item into the queue without blocking.
|
||||
|
||||
If no free slot is immediately available, raise QueueFull.
|
||||
|
||||
Raises QueueShutDown if the queue has been shut down.
|
||||
"""
|
||||
if self._is_shutdown:
|
||||
raise QueueShutDown
|
||||
if self.full():
|
||||
raise asyncio.QueueFull
|
||||
self._put(item)
|
||||
self._unfinished_tasks += 1
|
||||
self._finished.clear()
|
||||
self._wakeup_next(self._getters)
|
||||
|
||||
async def get(self) -> T:
|
||||
"""Remove and return an item from the queue.
|
||||
|
||||
If queue is empty, wait until an item is available.
|
||||
|
||||
Raises QueueShutDown if the queue has been shut down and is empty, or
|
||||
if the queue has been shut down immediately.
|
||||
"""
|
||||
while self.empty():
|
||||
if self._is_shutdown and self.empty():
|
||||
raise QueueShutDown
|
||||
getter = self._get_loop().create_future()
|
||||
self._getters.append(getter)
|
||||
try:
|
||||
await getter
|
||||
except:
|
||||
getter.cancel() # Just in case getter is not done yet.
|
||||
try:
|
||||
# Clean self._getters from canceled getters.
|
||||
self._getters.remove(getter)
|
||||
except ValueError:
|
||||
# The getter could be removed from self._getters by a
|
||||
# previous put_nowait call, or a shutdown call.
|
||||
pass
|
||||
if not self.empty() and not getter.cancelled():
|
||||
# We were woken up by put_nowait(), but can't take
|
||||
# the call. Wake up the next in line.
|
||||
self._wakeup_next(self._getters)
|
||||
raise
|
||||
return self.get_nowait()
|
||||
|
||||
def get_nowait(self) -> T:
|
||||
"""Remove and return an item from the queue.
|
||||
|
||||
Return an item if one is immediately available, else raise QueueEmpty.
|
||||
|
||||
Raises QueueShutDown if the queue has been shut down and is empty, or
|
||||
if the queue has been shut down immediately.
|
||||
"""
|
||||
if self.empty():
|
||||
if self._is_shutdown:
|
||||
raise QueueShutDown
|
||||
raise asyncio.QueueEmpty
|
||||
item = self._get()
|
||||
self._wakeup_next(self._putters)
|
||||
return item
|
||||
|
||||
def task_done(self) -> None:
|
||||
"""Indicate that a formerly enqueued task is complete.
|
||||
|
||||
Used by queue consumers. For each get() used to fetch a task,
|
||||
a subsequent call to task_done() tells the queue that the processing
|
||||
on the task is complete.
|
||||
|
||||
If a join() is currently blocking, it will resume when all items have
|
||||
been processed (meaning that a task_done() call was received for every
|
||||
item that had been put() into the queue).
|
||||
|
||||
shutdown(immediate=True) calls task_done() for each remaining item in
|
||||
the queue.
|
||||
|
||||
Raises ValueError if called more times than there were items placed in
|
||||
the queue.
|
||||
"""
|
||||
if self._unfinished_tasks <= 0:
|
||||
raise ValueError("task_done() called too many times")
|
||||
self._unfinished_tasks -= 1
|
||||
if self._unfinished_tasks == 0:
|
||||
self._finished.set()
|
||||
|
||||
async def join(self) -> None:
|
||||
"""Block until all items in the queue have been gotten and processed.
|
||||
|
||||
The count of unfinished tasks goes up whenever an item is added to the
|
||||
queue. The count goes down whenever a consumer calls task_done() to
|
||||
indicate that the item was retrieved and all work on it is complete.
|
||||
When the count of unfinished tasks drops to zero, join() unblocks.
|
||||
"""
|
||||
if self._unfinished_tasks > 0:
|
||||
await self._finished.wait()
|
||||
|
||||
def shutdown(self, immediate: bool = False) -> None:
|
||||
"""Shut-down the queue, making queue gets and puts raise QueueShutDown.
|
||||
|
||||
By default, gets will only raise once the queue is empty. Set
|
||||
'immediate' to True to make gets raise immediately instead.
|
||||
|
||||
All blocked callers of put() and get() will be unblocked. If
|
||||
'immediate', a task is marked as done for each item remaining in
|
||||
the queue, which may unblock callers of join().
|
||||
"""
|
||||
self._is_shutdown = True
|
||||
if immediate:
|
||||
while not self.empty():
|
||||
self._get()
|
||||
if self._unfinished_tasks > 0:
|
||||
self._unfinished_tasks -= 1
|
||||
if self._unfinished_tasks == 0:
|
||||
self._finished.set()
|
||||
# All getters need to re-check queue-empty to raise ShutDown
|
||||
while self._getters:
|
||||
getter = self._getters.popleft()
|
||||
if not getter.done():
|
||||
getter.set_result(None)
|
||||
while self._putters:
|
||||
putter = self._putters.popleft()
|
||||
if not putter.done():
|
||||
putter.set_result(None)
|
||||
@ -3,17 +3,24 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from asyncio import CancelledError, Future, Task
|
||||
from asyncio import CancelledError, Future, Queue, Task
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
||||
|
||||
from opentelemetry.trace import TracerProvider
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from asyncio import Queue, QueueShutDown
|
||||
else:
|
||||
from ._queue import Queue, QueueShutDown # type: ignore
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from ._agent import Agent
|
||||
from ._agent_id import AgentId
|
||||
from ._agent_instantiation import AgentInstantiationContext
|
||||
@ -100,48 +107,36 @@ class Counter:
|
||||
|
||||
|
||||
class RunContext:
|
||||
class RunState(Enum):
|
||||
RUNNING = 0
|
||||
CANCELLED = 1
|
||||
UNTIL_IDLE = 2
|
||||
|
||||
def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
|
||||
self._runtime = runtime
|
||||
self._run_state = RunContext.RunState.RUNNING
|
||||
self._end_condition: Callable[[], bool] = self._stop_when_cancelled
|
||||
self._run_task = asyncio.create_task(self._run())
|
||||
self._lock = asyncio.Lock()
|
||||
self._stopped = asyncio.Event()
|
||||
|
||||
async def _run(self) -> None:
|
||||
while True:
|
||||
async with self._lock:
|
||||
if self._end_condition():
|
||||
return
|
||||
if self._stopped.is_set():
|
||||
return
|
||||
|
||||
await self._runtime.process_next()
|
||||
await self._runtime._process_next() # type: ignore
|
||||
|
||||
async def stop(self) -> None:
|
||||
async with self._lock:
|
||||
self._run_state = RunContext.RunState.CANCELLED
|
||||
self._end_condition = self._stop_when_cancelled
|
||||
self._stopped.set()
|
||||
self._runtime._message_queue.shutdown(immediate=True) # type: ignore
|
||||
await self._run_task
|
||||
|
||||
async def stop_when_idle(self) -> None:
|
||||
async with self._lock:
|
||||
self._run_state = RunContext.RunState.UNTIL_IDLE
|
||||
self._end_condition = self._stop_when_idle
|
||||
await self._runtime._message_queue.join() # type: ignore
|
||||
self._stopped.set()
|
||||
self._runtime._message_queue.shutdown(immediate=True) # type: ignore
|
||||
await self._run_task
|
||||
|
||||
async def stop_when(self, condition: Callable[[], bool]) -> None:
|
||||
async with self._lock:
|
||||
self._end_condition = condition
|
||||
await self._run_task
|
||||
async def stop_when(self, condition: Callable[[], bool], check_period: float = 1.0) -> None:
|
||||
async def check_condition() -> None:
|
||||
while not condition():
|
||||
await asyncio.sleep(check_period)
|
||||
await self.stop()
|
||||
|
||||
def _stop_when_cancelled(self) -> bool:
|
||||
return self._run_state == RunContext.RunState.CANCELLED
|
||||
|
||||
def _stop_when_idle(self) -> bool:
|
||||
return self._run_state == RunContext.RunState.UNTIL_IDLE and self._runtime.idle
|
||||
await asyncio.create_task(check_condition())
|
||||
|
||||
|
||||
def _warn_if_none(value: Any, handler_name: str) -> None:
|
||||
@ -169,28 +164,23 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
tracer_provider: TracerProvider | None = None,
|
||||
) -> None:
|
||||
self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime"))
|
||||
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
|
||||
self._message_queue: Queue[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = Queue()
|
||||
# (namespace, type) -> List[AgentId]
|
||||
self._agent_factories: Dict[
|
||||
str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
|
||||
] = {}
|
||||
self._instantiated_agents: Dict[AgentId, Agent] = {}
|
||||
self._intervention_handlers = intervention_handlers
|
||||
self._outstanding_tasks = Counter()
|
||||
self._background_tasks: Set[Task[Any]] = set()
|
||||
self._subscription_manager = SubscriptionManager()
|
||||
self._run_context: RunContext | None = None
|
||||
self._serialization_registry = SerializationRegistry()
|
||||
|
||||
@property
|
||||
def unprocessed_messages(
|
||||
def unprocessed_messages_count(
|
||||
self,
|
||||
) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]:
|
||||
return self._message_queue
|
||||
|
||||
@property
|
||||
def outstanding_tasks(self) -> int:
|
||||
return self._outstanding_tasks.get()
|
||||
) -> int:
|
||||
return self._message_queue.qsize()
|
||||
|
||||
@property
|
||||
def _known_agent_names(self) -> Set[str]:
|
||||
@ -231,7 +221,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
content = message.__dict__ if hasattr(message, "__dict__") else message
|
||||
logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}")
|
||||
|
||||
self._message_queue.append(
|
||||
await self._message_queue.put(
|
||||
SendMessageEnvelope(
|
||||
message=message,
|
||||
recipient=recipient,
|
||||
@ -279,7 +269,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
# )
|
||||
# )
|
||||
|
||||
self._message_queue.append(
|
||||
await self._message_queue.put(
|
||||
PublishMessageEnvelope(
|
||||
message=message,
|
||||
cancellation_token=cancellation_token,
|
||||
@ -340,14 +330,14 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
except CancelledError as e:
|
||||
if not message_envelope.future.cancelled():
|
||||
message_envelope.future.set_exception(e)
|
||||
self._outstanding_tasks.decrement()
|
||||
self._message_queue.task_done()
|
||||
return
|
||||
except BaseException as e:
|
||||
message_envelope.future.set_exception(e)
|
||||
self._outstanding_tasks.decrement()
|
||||
self._message_queue.task_done()
|
||||
return
|
||||
|
||||
self._message_queue.append(
|
||||
await self._message_queue.put(
|
||||
ResponseMessageEnvelope(
|
||||
message=response,
|
||||
future=message_envelope.future,
|
||||
@ -356,7 +346,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
metadata=get_telemetry_envelope_metadata(),
|
||||
)
|
||||
)
|
||||
self._outstanding_tasks.decrement()
|
||||
self._message_queue.task_done()
|
||||
|
||||
async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
|
||||
with self._tracer_helper.trace_block("publish", message_envelope.topic_id, parent=message_envelope.metadata):
|
||||
@ -411,7 +401,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
return
|
||||
logger.error("Error processing publish message", exc_info=True)
|
||||
finally:
|
||||
self._outstanding_tasks.decrement()
|
||||
self._message_queue.task_done()
|
||||
# TODO if responses are given for a publish
|
||||
|
||||
async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
|
||||
@ -433,18 +423,21 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
# delivery_stage=DeliveryStage.DELIVER,
|
||||
# )
|
||||
# )
|
||||
self._outstanding_tasks.decrement()
|
||||
self._message_queue.task_done()
|
||||
if not message_envelope.future.cancelled():
|
||||
message_envelope.future.set_result(message_envelope.message)
|
||||
|
||||
@deprecated("Manually stepping the runtime processing is deprecated. Use start() instead.")
|
||||
async def process_next(self) -> None:
|
||||
await self._process_next()
|
||||
|
||||
async def _process_next(self) -> None:
|
||||
"""Process the next message in the queue."""
|
||||
|
||||
if len(self._message_queue) == 0:
|
||||
# Yield control to the event loop to allow other tasks to run
|
||||
await asyncio.sleep(0)
|
||||
try:
|
||||
message_envelope = await self._message_queue.get()
|
||||
except QueueShutDown:
|
||||
return
|
||||
message_envelope = self._message_queue.pop(0)
|
||||
|
||||
match message_envelope:
|
||||
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
|
||||
@ -464,7 +457,6 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
return
|
||||
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
task = asyncio.create_task(self._process_send(message_envelope))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
@ -489,7 +481,6 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
return
|
||||
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
task = asyncio.create_task(self._process_publish(message_envelope))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
@ -507,7 +498,6 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
future.set_exception(MessageDroppedException())
|
||||
return
|
||||
message_envelope.message = temp_message
|
||||
self._outstanding_tasks.increment()
|
||||
task = asyncio.create_task(self._process_response(message_envelope))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
@ -515,37 +505,59 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
# Yield control to the message loop to allow other tasks to run
|
||||
await asyncio.sleep(0)
|
||||
|
||||
@property
|
||||
def idle(self) -> bool:
|
||||
return len(self._message_queue) == 0 and self._outstanding_tasks.get() == 0
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the runtime message processing loop."""
|
||||
"""Start the runtime message processing loop. This runs in a background task.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from autogen_core import SingleThreadedAgentRuntime
|
||||
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
runtime.start()
|
||||
|
||||
"""
|
||||
if self._run_context is not None:
|
||||
raise RuntimeError("Runtime is already started")
|
||||
self._run_context = RunContext(self)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the runtime message processing loop."""
|
||||
"""Immediately stop the runtime message processing loop. The currently processing message will be completed, but all others following it will be discarded."""
|
||||
if self._run_context is None:
|
||||
raise RuntimeError("Runtime is not started")
|
||||
await self._run_context.stop()
|
||||
self._run_context = None
|
||||
self._message_queue = Queue()
|
||||
|
||||
async def stop_when_idle(self) -> None:
|
||||
"""Stop the runtime message processing loop when there is
|
||||
no outstanding message being processed or queued."""
|
||||
no outstanding message being processed or queued. This is the most common way to stop the runtime."""
|
||||
if self._run_context is None:
|
||||
raise RuntimeError("Runtime is not started")
|
||||
await self._run_context.stop_when_idle()
|
||||
self._run_context = None
|
||||
self._message_queue = Queue()
|
||||
|
||||
async def stop_when(self, condition: Callable[[], bool]) -> None:
|
||||
"""Stop the runtime message processing loop when the condition is met."""
|
||||
"""Stop the runtime message processing loop when the condition is met.
|
||||
|
||||
.. caution::
|
||||
|
||||
This method is not recommended to be used, and is here for legacy
|
||||
reasons. It will spawn a busy loop to continually check the
|
||||
condition. It is much more efficient to call `stop_when_idle` or
|
||||
`stop` instead. If you need to stop the runtime based on a
|
||||
condition, consider using a background task and asyncio.Event to
|
||||
signal when the condition is met and the background task should call
|
||||
stop.
|
||||
|
||||
"""
|
||||
if self._run_context is None:
|
||||
raise RuntimeError("Runtime is not started")
|
||||
await self._run_context.stop_when(condition)
|
||||
self._run_context = None
|
||||
self._message_queue = Queue()
|
||||
|
||||
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
|
||||
return (await self._get_agent(agent)).metadata
|
||||
|
||||
@ -71,10 +71,10 @@ async def test_cancellation_with_token() -> None:
|
||||
response = asyncio.create_task(runtime.send_message(MessageType(), recipient=agent_id, cancellation_token=token))
|
||||
assert not response.done()
|
||||
|
||||
while len(runtime.unprocessed_messages) == 0:
|
||||
while runtime.unprocessed_messages_count == 0:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await runtime.process_next()
|
||||
await runtime._process_next() # type: ignore
|
||||
|
||||
token.cancel()
|
||||
|
||||
@ -104,10 +104,10 @@ async def test_nested_cancellation_only_outer_called() -> None:
|
||||
response = asyncio.create_task(runtime.send_message(MessageType(), nested_id, cancellation_token=token))
|
||||
assert not response.done()
|
||||
|
||||
while len(runtime.unprocessed_messages) == 0:
|
||||
while runtime.unprocessed_messages_count == 0:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await runtime.process_next()
|
||||
await runtime._process_next() # type: ignore
|
||||
token.cancel()
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
@ -140,12 +140,12 @@ async def test_nested_cancellation_inner_called() -> None:
|
||||
response = asyncio.create_task(runtime.send_message(MessageType(), nested_id, cancellation_token=token))
|
||||
assert not response.done()
|
||||
|
||||
while len(runtime.unprocessed_messages) == 0:
|
||||
while runtime.unprocessed_messages_count == 0:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await runtime.process_next()
|
||||
await runtime._process_next() # type: ignore
|
||||
# allow the inner agent to process
|
||||
await runtime.process_next()
|
||||
await runtime._process_next() # type: ignore
|
||||
token.cancel()
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user