2024-06-28 10:22:44 -04:00
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
from agnext.application import SingleThreadedAgentRuntime
|
|
|
|
|
2024-08-20 14:41:24 -04:00
|
|
|
from agnext.components._type_subscription import TypeSubscription
|
2024-08-16 23:14:09 -04:00
|
|
|
from agnext.core import AgentRuntime, AgentId
|
2024-06-28 10:22:44 -04:00
|
|
|
|
|
|
|
from agnext.components import ClosureAgent
|
|
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
2024-08-16 23:14:09 -04:00
|
|
|
from agnext.core import MessageContext
|
2024-08-20 14:41:24 -04:00
|
|
|
from agnext.core import TopicId
|
2024-08-16 23:14:09 -04:00
|
|
|
|
2024-06-28 10:22:44 -04:00
|
|
|
@dataclass
class Message:
    """Test payload published to the runtime; carries a single text field."""

    # Text recorded by the receiving closure and compared in the assertions.
    content: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_register_receives_publish() -> None:
    """Published messages reach a registered ClosureAgent, in publish order.

    Registers a ClosureAgent under the agent type ``"name"``, adds a
    ``TypeSubscription("default", "name")``, then publishes three messages to
    ``TopicId("default", "default")``. The test checks that the closure saw
    every message, in the order published, each with agent key ``"default"``.
    """
    runtime = SingleThreadedAgentRuntime()
    # (agent key, message content) pairs recorded by the closure, in arrival order.
    queue = asyncio.Queue[tuple[str, str]]()

    # The second parameter is named ``agent_id`` (not ``id``) so that it does
    # not shadow the builtin ``id``; its position in the signature is unchanged.
    async def log_message(_runtime: AgentRuntime, agent_id: AgentId, message: Message, ctx: MessageContext) -> None:
        await queue.put((agent_id.key, message.content))

    await runtime.register("name", lambda: ClosureAgent("my_agent", log_message))
    await runtime.add_subscription(TypeSubscription("default", "name"))
    topic_id = TopicId("default", "default")

    runtime.start()
    for content in ("first message", "second message", "third message"):
        await runtime.publish_message(Message(content), topic_id=topic_id)
    # The assertions below are made only after this await returns, so every
    # publish is expected to have been handled by then.
    await runtime.stop_when_idle()

    # Drain into a list so that a failing assertion shows everything received,
    # not just the first mismatch.
    received = [queue.get_nowait() for _ in range(queue.qsize())]
    assert received == [
        ("default", "first message"),
        ("default", "second message"),
        ("default", "third message"),
    ]
    assert queue.empty()
|