2024-07-01 11:53:45 -04:00
from __future__ import annotations
2024-05-15 12:31:13 -04:00
import asyncio
2024-06-18 14:53:18 -04:00
import inspect
2024-06-04 10:17:04 -04:00
import logging
2024-06-18 14:53:18 -04:00
import threading
2024-08-02 11:02:45 -04:00
import warnings
2024-07-01 11:53:45 -04:00
from asyncio import CancelledError , Future , Task
2024-06-18 14:53:18 -04:00
from collections import defaultdict
2024-06-04 10:17:04 -04:00
from collections . abc import Sequence
2024-05-15 12:31:13 -04:00
from dataclasses import dataclass
2024-07-01 11:53:45 -04:00
from enum import Enum
2024-07-23 16:38:37 -07:00
from typing import Any , Awaitable , Callable , DefaultDict , Dict , List , Mapping , ParamSpec , Set , Type , TypeVar , cast
2024-05-15 12:31:13 -04:00
2024-06-21 10:47:51 -04:00
from . . core import (
2024-07-08 16:45:14 -04:00
MESSAGE_TYPE_REGISTRY ,
2024-06-21 10:47:51 -04:00
Agent ,
AgentId ,
2024-08-02 11:02:45 -04:00
AgentInstantiationContext ,
2024-06-21 10:47:51 -04:00
AgentMetadata ,
AgentProxy ,
AgentRuntime ,
CancellationToken ,
)
2024-05-27 17:10:56 -04:00
from . . core . exceptions import MessageDroppedException
from . . core . intervention import DropMessage , InterventionHandler
2024-05-15 12:31:13 -04:00
2024-06-04 10:17:04 -04:00
# Module logger for runtime diagnostics (message routing, handler errors).
logger = logging.getLogger("agnext")
# Child logger "agnext.events"; getChild returns the very same object as
# logging.getLogger("agnext.events").
event_logger = logger.getChild("events")
2024-06-04 10:17:04 -04:00
2024-05-15 12:31:13 -04:00
2024-05-20 17:30:45 -06:00
@dataclass ( kw_only = True )
2024-05-26 08:45:02 -04:00
class PublishMessageEnvelope :
""" A message envelope for publishing messages to all agents that can handle
2024-05-17 14:59:00 -07:00
the message of the type T . """
2024-05-23 16:00:05 -04:00
message : Any
2024-05-20 13:32:08 -06:00
cancellation_token : CancellationToken
2024-06-18 14:53:18 -04:00
sender : AgentId | None
namespace : str
2024-05-15 12:31:13 -04:00
2024-05-20 17:30:45 -06:00
@dataclass ( kw_only = True )
2024-05-23 16:00:05 -04:00
class SendMessageEnvelope :
2024-05-17 14:59:00 -07:00
""" A message envelope for sending a message to a specific agent that can handle
the message of the type T . """
2024-05-23 16:00:05 -04:00
message : Any
2024-06-18 14:53:18 -04:00
sender : AgentId | None
recipient : AgentId
2024-05-26 08:45:02 -04:00
future : Future [ Any ]
2024-05-20 13:32:08 -06:00
cancellation_token : CancellationToken
2024-05-15 12:31:13 -04:00
2024-05-20 17:30:45 -06:00
@dataclass ( kw_only = True )
2024-05-23 16:00:05 -04:00
class ResponseMessageEnvelope :
2024-05-17 14:59:00 -07:00
""" A message envelope for sending a response to a message. """
2024-05-23 16:00:05 -04:00
message : Any
future : Future [ Any ]
2024-06-18 14:53:18 -04:00
sender : AgentId
recipient : AgentId | None
# Type variable for concrete agent classes; used by agent factories and typed
# agent lookups in the runtime below.
T = TypeVar("T", bound=Agent)
# Parameter-spec variable. NOTE(review): not referenced in the code visible here.
P = ParamSpec("P")
class Counter:
    """A lock-protected integer counter.

    The runtime uses it to count messages whose background processing has not
    finished yet.
    """

    def __init__(self) -> None:
        self._count: int = 0
        self.threadLock = threading.Lock()

    def increment(self) -> None:
        """Add one to the count."""
        # ``with`` guarantees the lock is released even if the body raises.
        with self.threadLock:
            self._count += 1

    def get(self) -> int:
        """Return the current count."""
        with self.threadLock:
            return self._count

    def decrement(self) -> None:
        """Subtract one from the count."""
        with self.threadLock:
            self._count -= 1
2024-05-19 17:12:49 -06:00
2024-07-01 11:53:45 -04:00
class RunContext:
    """Drives a runtime's message loop in a background asyncio task.

    The task calls ``runtime.process_next()`` repeatedly. It stops either
    immediately (:meth:`stop`) or once ``runtime.idle`` is true
    (:meth:`stop_when_idle`).
    """

    class RunState(Enum):
        """Lifecycle states of the background loop."""

        RUNNING = 0
        CANCELLED = 1
        UNTIL_IDLE = 2

    def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
        self._runtime = runtime
        self._run_state = RunContext.RunState.RUNNING
        # Guards _run_state. Each loop iteration holds it while processing one message.
        self._lock = asyncio.Lock()
        self._run_task = asyncio.create_task(self._run())

    async def _run(self) -> None:
        states = RunContext.RunState
        while True:
            async with self._lock:
                current = self._run_state
                finished = current == states.CANCELLED or (current == states.UNTIL_IDLE and self._runtime.idle)
                if finished:
                    return
                await self._runtime.process_next()

    async def _finish(self, state: RunContext.RunState) -> None:
        # Switch state under the lock, then wait for the loop task to exit.
        async with self._lock:
            self._run_state = state
        await self._run_task

    async def stop(self) -> None:
        """Stop the loop at its next iteration and wait for it to exit."""
        await self._finish(RunContext.RunState.CANCELLED)

    async def stop_when_idle(self) -> None:
        """Keep processing until the runtime is idle, then wait for the loop to exit."""
        await self._finish(RunContext.RunState.UNTIL_IDLE)
2024-05-23 16:00:05 -04:00
class SingleThreadedAgentRuntime ( AgentRuntime ) :
2024-06-24 16:52:09 -04:00
def __init__ ( self , * , intervention_handler : InterventionHandler | None = None ) - > None :
2024-05-26 08:45:02 -04:00
self . _message_queue : List [ PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope ] = [ ]
2024-06-18 14:53:18 -04:00
# (namespace, type) -> List[AgentId]
2024-07-08 16:45:14 -04:00
self . _per_type_subscribers : DefaultDict [ tuple [ str , str ] , Set [ AgentId ] ] = defaultdict ( set )
2024-07-23 11:49:38 -07:00
self . _agent_factories : Dict [
str , Callable [ [ ] , Agent | Awaitable [ Agent ] ] | Callable [ [ AgentRuntime , AgentId ] , Agent | Awaitable [ Agent ] ]
] = { }
2024-06-18 14:53:18 -04:00
self . _instantiated_agents : Dict [ AgentId , Agent ] = { }
2024-06-24 16:52:09 -04:00
self . _intervention_handler = intervention_handler
2024-06-18 14:53:18 -04:00
self . _known_namespaces : set [ str ] = set ( )
self . _outstanding_tasks = Counter ( )
2024-07-01 11:53:45 -04:00
self . _background_tasks : Set [ Task [ Any ] ] = set ( )
2024-06-04 10:17:04 -04:00
@property
2024-06-07 13:33:51 -07:00
def unprocessed_messages (
self ,
) - > Sequence [ PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope ] :
2024-06-04 10:17:04 -04:00
return self . _message_queue
2024-06-18 14:53:18 -04:00
@property
def outstanding_tasks ( self ) - > int :
return self . _outstanding_tasks . get ( )
@property
def _known_agent_names ( self ) - > Set [ str ] :
return set ( self . _agent_factories . keys ( ) )
2024-05-15 12:31:13 -04:00
# Returns the response of the message
2024-06-27 13:40:12 -04:00
async def send_message (
2024-05-20 17:30:45 -06:00
self ,
2024-05-23 16:00:05 -04:00
message : Any ,
2024-06-18 14:53:18 -04:00
recipient : AgentId ,
2024-05-20 17:30:45 -06:00
* ,
2024-06-18 14:53:18 -04:00
sender : AgentId | None = None ,
2024-05-20 17:30:45 -06:00
cancellation_token : CancellationToken | None = None ,
2024-07-01 11:53:45 -04:00
) - > Any :
2024-05-20 13:32:08 -06:00
if cancellation_token is None :
cancellation_token = CancellationToken ( )
2024-06-07 13:33:51 -07:00
# event_logger.info(
# MessageEvent(
# payload=message,
# sender=sender,
# receiver=recipient,
# kind=MessageKind.DIRECT,
# delivery_stage=DeliveryStage.SEND,
# )
# )
2024-05-23 16:00:05 -04:00
future = asyncio . get_event_loop ( ) . create_future ( )
2024-08-07 13:25:44 -04:00
if recipient . type not in self . _known_agent_names :
2024-05-23 16:00:05 -04:00
future . set_exception ( Exception ( " Recipient not found " ) )
2024-05-15 12:31:13 -04:00
2024-08-07 13:25:44 -04:00
if sender is not None and sender . key != recipient . key :
2024-06-18 14:53:18 -04:00
raise ValueError ( " Sender and recipient must be in the same namespace to communicate. " )
2024-08-07 13:25:44 -04:00
await self . _process_seen_namespace ( recipient . key )
2024-06-19 10:49:08 -04:00
2024-06-28 15:27:00 -07:00
content = message . __dict__ if hasattr ( message , " __dict__ " ) else message
2024-08-07 13:25:44 -04:00
logger . info ( f " Sending message of type { type ( message ) . __name__ } to { recipient . type } : { content } " )
2024-06-18 14:53:18 -04:00
2024-05-20 17:30:45 -06:00
self . _message_queue . append (
SendMessageEnvelope (
message = message ,
recipient = recipient ,
future = future ,
cancellation_token = cancellation_token ,
sender = sender ,
)
)
2024-05-20 13:32:08 -06:00
2024-07-01 11:53:45 -04:00
cancellation_token . link_future ( future )
return await future
2024-05-15 12:31:13 -04:00
2024-06-27 13:40:12 -04:00
async def publish_message (
2024-05-23 16:00:05 -04:00
self ,
message : Any ,
* ,
2024-06-18 14:53:18 -04:00
namespace : str | None = None ,
sender : AgentId | None = None ,
2024-05-23 16:00:05 -04:00
cancellation_token : CancellationToken | None = None ,
2024-06-27 13:40:12 -04:00
) - > None :
2024-05-20 13:32:08 -06:00
if cancellation_token is None :
cancellation_token = CancellationToken ( )
2024-06-28 15:27:00 -07:00
content = message . __dict__ if hasattr ( message , " __dict__ " ) else message
logger . info ( f " Publishing message of type { type ( message ) . __name__ } to all subscribers: { content } " )
2024-06-07 13:33:51 -07:00
# event_logger.info(
# MessageEvent(
# payload=message,
# sender=sender,
# receiver=None,
# kind=MessageKind.PUBLISH,
# delivery_stage=DeliveryStage.SEND,
# )
# )
2024-06-18 14:53:18 -04:00
if sender is None and namespace is None :
raise ValueError ( " Namespace must be provided if sender is not provided. " )
2024-08-07 13:25:44 -04:00
sender_namespace = sender . key if sender is not None else None
2024-06-18 14:53:18 -04:00
explicit_namespace = namespace
if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace :
raise ValueError (
f " Explicit namespace { explicit_namespace } does not match sender namespace { sender_namespace } "
)
assert explicit_namespace is not None or sender_namespace is not None
namespace = cast ( str , explicit_namespace or sender_namespace )
2024-07-23 11:49:38 -07:00
await self . _process_seen_namespace ( namespace )
2024-06-19 10:49:08 -04:00
2024-05-20 17:30:45 -06:00
self . _message_queue . append (
2024-05-26 08:45:02 -04:00
PublishMessageEnvelope (
2024-05-23 16:00:05 -04:00
message = message ,
cancellation_token = cancellation_token ,
sender = sender ,
2024-06-18 14:53:18 -04:00
namespace = namespace ,
2024-05-20 17:30:45 -06:00
)
)
2024-05-23 16:00:05 -04:00
2024-07-23 11:49:38 -07:00
async def save_state ( self ) - > Mapping [ str , Any ] :
2024-05-27 20:25:25 -04:00
state : Dict [ str , Dict [ str , Any ] ] = { }
2024-06-18 14:53:18 -04:00
for agent_id in self . _instantiated_agents :
2024-07-23 11:49:38 -07:00
state [ str ( agent_id ) ] = dict ( ( await self . _get_agent ( agent_id ) ) . save_state ( ) )
2024-05-27 20:25:25 -04:00
return state
2024-07-23 11:49:38 -07:00
async def load_state ( self , state : Mapping [ str , Any ] ) - > None :
2024-06-18 14:53:18 -04:00
for agent_id_str in state :
agent_id = AgentId . from_str ( agent_id_str )
2024-08-07 13:25:44 -04:00
if agent_id . type in self . _known_agent_names :
2024-07-23 11:49:38 -07:00
( await self . _get_agent ( agent_id ) ) . load_state ( state [ str ( agent_id ) ] )
2024-05-27 20:25:25 -04:00
2024-05-23 16:00:05 -04:00
async def _process_send ( self , message_envelope : SendMessageEnvelope ) - > None :
2024-05-20 17:30:45 -06:00
recipient = message_envelope . recipient
2024-06-18 14:53:18 -04:00
# todo: check if recipient is in the known namespaces
# assert recipient in self._agents
2024-05-15 12:31:13 -04:00
2024-05-20 13:32:08 -06:00
try :
2024-08-07 13:25:44 -04:00
# TODO use id
sender_name = message_envelope . sender . type if message_envelope . sender is not None else " Unknown "
2024-06-04 10:17:04 -04:00
logger . info (
2024-06-18 14:53:18 -04:00
f " Calling message handler for { recipient } with message type { type ( message_envelope . message ) . __name__ } sent by { sender_name } "
2024-06-04 10:17:04 -04:00
)
2024-06-07 13:33:51 -07:00
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=recipient,
# kind=MessageKind.DIRECT,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
2024-07-23 11:49:38 -07:00
recipient_agent = await self . _get_agent ( recipient )
2024-06-18 14:53:18 -04:00
response = await recipient_agent . on_message (
2024-05-23 16:00:05 -04:00
message_envelope . message ,
cancellation_token = message_envelope . cancellation_token ,
2024-05-20 13:32:08 -06:00
)
except BaseException as e :
message_envelope . future . set_exception ( e )
return
2024-05-26 08:45:02 -04:00
self . _message_queue . append (
ResponseMessageEnvelope (
message = response ,
future = message_envelope . future ,
sender = message_envelope . recipient ,
recipient = message_envelope . sender ,
2024-05-23 16:00:05 -04:00
)
2024-05-26 08:45:02 -04:00
)
2024-06-18 14:53:18 -04:00
self . _outstanding_tasks . decrement ( )
2024-05-23 16:00:05 -04:00
2024-05-26 08:45:02 -04:00
async def _process_publish ( self , message_envelope : PublishMessageEnvelope ) - > None :
2024-05-23 16:00:05 -04:00
responses : List [ Awaitable [ Any ] ] = [ ]
2024-06-18 14:53:18 -04:00
target_namespace = message_envelope . namespace
2024-07-08 16:45:14 -04:00
for agent_id in self . _per_type_subscribers [
( target_namespace , MESSAGE_TYPE_REGISTRY . type_name ( message_envelope . message ) )
] :
2024-08-07 13:25:44 -04:00
if message_envelope . sender is not None and agent_id . type == message_envelope . sender . type :
2024-05-28 16:21:40 -04:00
continue
2024-06-04 10:17:04 -04:00
2024-07-23 11:49:38 -07:00
sender_agent = (
await self . _get_agent ( message_envelope . sender ) if message_envelope . sender is not None else None
)
2024-08-07 16:08:13 -04:00
sender_name = str ( sender_agent . id ) if sender_agent is not None else " Unknown "
2024-06-04 10:17:04 -04:00
logger . info (
2024-08-07 13:25:44 -04:00
f " Calling message handler for { agent_id . type } with message type { type ( message_envelope . message ) . __name__ } published by { sender_name } "
2024-06-04 10:17:04 -04:00
)
2024-06-07 13:33:51 -07:00
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=agent,
# kind=MessageKind.PUBLISH,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
2024-06-04 10:17:04 -04:00
2024-07-23 11:49:38 -07:00
agent = await self . _get_agent ( agent_id )
2024-05-23 16:00:05 -04:00
future = agent . on_message (
message_envelope . message ,
cancellation_token = message_envelope . cancellation_token ,
)
2024-05-19 17:12:49 -06:00
responses . append ( future )
2024-05-20 13:32:08 -06:00
try :
2024-05-26 08:45:02 -04:00
_all_responses = await asyncio . gather ( * responses )
2024-06-24 19:54:19 -04:00
except BaseException as e :
# Ignore cancelled errors from logs
if isinstance ( e , CancelledError ) :
return
2024-06-04 10:17:04 -04:00
logger . error ( " Error processing publish message " , exc_info = True )
2024-06-24 19:54:19 -04:00
finally :
self . _outstanding_tasks . decrement ( )
2024-05-26 08:45:02 -04:00
# TODO if responses are given for a publish
2024-05-19 17:12:49 -06:00
2024-05-23 16:00:05 -04:00
async def _process_response ( self , message_envelope : ResponseMessageEnvelope ) - > None :
2024-06-07 13:33:51 -07:00
content = (
message_envelope . message . __dict__
if hasattr ( message_envelope . message , " __dict__ " )
else message_envelope . message
)
2024-06-04 10:17:04 -04:00
logger . info (
2024-08-07 13:25:44 -04:00
f " Resolving response with message type { type ( message_envelope . message ) . __name__ } for recipient { message_envelope . recipient } from { message_envelope . sender . type } : { content } "
2024-06-04 10:17:04 -04:00
)
2024-06-07 13:33:51 -07:00
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=message_envelope.recipient,
# kind=MessageKind.RESPOND,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
2024-06-18 14:53:18 -04:00
self . _outstanding_tasks . decrement ( )
2024-05-19 17:12:49 -06:00
message_envelope . future . set_result ( message_envelope . message )
2024-05-15 12:31:13 -04:00
async def process_next ( self ) - > None :
2024-06-27 11:46:06 -07:00
""" Process the next message in the queue. """
2024-05-17 14:59:00 -07:00
if len ( self . _message_queue ) == 0 :
2024-05-15 12:31:13 -04:00
# Yield control to the event loop to allow other tasks to run
await asyncio . sleep ( 0 )
return
2024-05-17 14:59:00 -07:00
message_envelope = self . _message_queue . pop ( 0 )
2024-05-15 12:31:13 -04:00
2024-05-17 14:59:00 -07:00
match message_envelope :
2024-05-20 17:30:45 -06:00
case SendMessageEnvelope ( message = message , sender = sender , recipient = recipient , future = future ) :
2024-06-24 16:52:09 -04:00
if self . _intervention_handler is not None :
2024-06-17 17:34:56 -07:00
try :
2024-06-24 16:52:09 -04:00
temp_message = await self . _intervention_handler . on_send (
message , sender = sender , recipient = recipient
)
2024-06-17 17:34:56 -07:00
except BaseException as e :
future . set_exception ( e )
return
2024-05-20 17:30:45 -06:00
if temp_message is DropMessage or isinstance ( temp_message , DropMessage ) :
future . set_exception ( MessageDroppedException ( ) )
return
2024-05-23 16:00:05 -04:00
message_envelope . message = temp_message
2024-06-18 14:53:18 -04:00
self . _outstanding_tasks . increment ( )
2024-07-01 11:53:45 -04:00
task = asyncio . create_task ( self . _process_send ( message_envelope ) )
self . _background_tasks . add ( task )
task . add_done_callback ( self . _background_tasks . discard )
2024-05-26 08:45:02 -04:00
case PublishMessageEnvelope (
2024-05-20 17:30:45 -06:00
message = message ,
sender = sender ,
) :
2024-06-24 16:52:09 -04:00
if self . _intervention_handler is not None :
2024-06-17 17:34:56 -07:00
try :
2024-06-24 16:52:09 -04:00
temp_message = await self . _intervention_handler . on_publish ( message , sender = sender )
2024-06-17 17:34:56 -07:00
except BaseException as e :
# TODO: we should raise the intervention exception to the publisher.
logger . error ( f " Exception raised in in intervention handler: { e } " , exc_info = True )
return
2024-05-20 17:30:45 -06:00
if temp_message is DropMessage or isinstance ( temp_message , DropMessage ) :
2024-05-26 08:45:02 -04:00
# TODO log message dropped
2024-05-20 17:30:45 -06:00
return
2024-05-23 16:00:05 -04:00
message_envelope . message = temp_message
2024-06-18 14:53:18 -04:00
self . _outstanding_tasks . increment ( )
2024-07-01 11:53:45 -04:00
task = asyncio . create_task ( self . _process_publish ( message_envelope ) )
self . _background_tasks . add ( task )
task . add_done_callback ( self . _background_tasks . discard )
2024-05-20 17:30:45 -06:00
case ResponseMessageEnvelope ( message = message , sender = sender , recipient = recipient , future = future ) :
2024-06-24 16:52:09 -04:00
if self . _intervention_handler is not None :
2024-06-17 17:34:56 -07:00
try :
2024-06-24 16:52:09 -04:00
temp_message = await self . _intervention_handler . on_response (
message , sender = sender , recipient = recipient
)
2024-06-17 17:34:56 -07:00
except BaseException as e :
# TODO: should we raise the exception to sender of the response instead?
future . set_exception ( e )
return
2024-05-20 17:30:45 -06:00
if temp_message is DropMessage or isinstance ( temp_message , DropMessage ) :
future . set_exception ( MessageDroppedException ( ) )
return
2024-05-23 16:00:05 -04:00
message_envelope . message = temp_message
2024-06-18 14:53:18 -04:00
self . _outstanding_tasks . increment ( )
2024-07-01 11:53:45 -04:00
task = asyncio . create_task ( self . _process_response ( message_envelope ) )
self . _background_tasks . add ( task )
task . add_done_callback ( self . _background_tasks . discard )
2024-05-20 17:30:45 -06:00
2024-05-17 14:59:00 -07:00
# Yield control to the message loop to allow other tasks to run
2024-05-15 12:31:13 -04:00
await asyncio . sleep ( 0 )
2024-06-17 10:44:46 -04:00
2024-07-01 11:53:45 -04:00
@property
def idle ( self ) - > bool :
return len ( self . _message_queue ) == 0 and self . _outstanding_tasks . get ( ) == 0
2024-06-27 11:46:06 -07:00
2024-07-01 11:53:45 -04:00
def start ( self ) - > RunContext :
return RunContext ( self )
2024-06-27 11:46:06 -07:00
2024-07-23 11:49:38 -07:00
async def agent_metadata ( self , agent : AgentId ) - > AgentMetadata :
return ( await self . _get_agent ( agent ) ) . metadata
2024-06-17 15:37:46 -04:00
2024-07-23 11:49:38 -07:00
async def agent_save_state ( self , agent : AgentId ) - > Mapping [ str , Any ] :
return ( await self . _get_agent ( agent ) ) . save_state ( )
2024-06-17 12:43:51 -04:00
2024-07-23 11:49:38 -07:00
async def agent_load_state ( self , agent : AgentId , state : Mapping [ str , Any ] ) - > None :
( await self . _get_agent ( agent ) ) . load_state ( state )
2024-06-17 15:37:46 -04:00
2024-07-23 11:49:38 -07:00
async def register (
2024-06-18 14:53:18 -04:00
self ,
name : str ,
2024-07-23 11:49:38 -07:00
agent_factory : Callable [ [ ] , T | Awaitable [ T ] ] | Callable [ [ AgentRuntime , AgentId ] , T | Awaitable [ T ] ] ,
2024-06-18 14:53:18 -04:00
) - > None :
if name in self . _agent_factories :
raise ValueError ( f " Agent with name { name } already exists. " )
self . _agent_factories [ name ] = agent_factory
2024-06-19 10:49:08 -04:00
# For all already prepared namespaces we need to prepare this agent
for namespace in self . _known_namespaces :
2024-08-07 13:25:44 -04:00
await self . _get_agent ( AgentId ( type = name , key = namespace ) )
2024-06-19 10:49:08 -04:00
2024-07-23 11:49:38 -07:00
async def _invoke_agent_factory (
self ,
agent_factory : Callable [ [ ] , T | Awaitable [ T ] ] | Callable [ [ AgentRuntime , AgentId ] , T | Awaitable [ T ] ] ,
agent_id : AgentId ,
2024-06-18 14:53:18 -04:00
) - > T :
2024-08-02 11:02:45 -04:00
with AgentInstantiationContext . populate_context ( ( self , agent_id ) ) :
2024-07-23 11:49:38 -07:00
if len ( inspect . signature ( agent_factory ) . parameters ) == 0 :
factory_one = cast ( Callable [ [ ] , T ] , agent_factory )
agent = factory_one ( )
elif len ( inspect . signature ( agent_factory ) . parameters ) == 2 :
2024-08-02 11:02:45 -04:00
warnings . warn (
" Agent factories that take two arguments are deprecated. Use AgentInstantiationContext instead. Two arg factories will be removed in a future version. " ,
stacklevel = 2 ,
)
2024-07-23 11:49:38 -07:00
factory_two = cast ( Callable [ [ AgentRuntime , AgentId ] , T ] , agent_factory )
agent = factory_two ( self , agent_id )
else :
raise ValueError ( " Agent factory must take 0 or 2 arguments. " )
if inspect . isawaitable ( agent ) :
return cast ( T , await agent )
return agent
async def _get_agent ( self , agent_id : AgentId ) - > Agent :
2024-08-07 13:25:44 -04:00
await self . _process_seen_namespace ( agent_id . key )
2024-06-18 14:53:18 -04:00
if agent_id in self . _instantiated_agents :
return self . _instantiated_agents [ agent_id ]
2024-08-07 13:25:44 -04:00
if agent_id . type not in self . _agent_factories :
raise LookupError ( f " Agent with name { agent_id . type } not found. " )
2024-06-18 14:53:18 -04:00
2024-08-07 13:25:44 -04:00
agent_factory = self . _agent_factories [ agent_id . type ]
2024-06-18 14:53:18 -04:00
2024-07-23 11:49:38 -07:00
agent = await self . _invoke_agent_factory ( agent_factory , agent_id )
2024-06-18 14:53:18 -04:00
for message_type in agent . metadata [ " subscriptions " ] :
2024-08-07 13:25:44 -04:00
self . _per_type_subscribers [ ( agent_id . key , message_type ) ] . add ( agent_id )
2024-06-18 14:53:18 -04:00
self . _instantiated_agents [ agent_id ] = agent
return agent
2024-07-23 11:49:38 -07:00
async def get ( self , name : str , * , namespace : str = " default " ) - > AgentId :
2024-08-07 13:25:44 -04:00
return ( await self . _get_agent ( AgentId ( type = name , key = namespace ) ) ) . id
2024-06-18 14:53:18 -04:00
2024-07-23 11:49:38 -07:00
async def get_proxy ( self , name : str , * , namespace : str = " default " ) - > AgentProxy :
id = await self . get ( name , namespace = namespace )
2024-06-18 14:53:18 -04:00
return AgentProxy ( id , self )
2024-07-23 16:38:37 -07:00
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
async def try_get_underlying_agent_instance ( self , id : AgentId , type : Type [ T ] = Agent ) - > T : # type: ignore[assignment]
2024-08-07 13:25:44 -04:00
if id . type not in self . _agent_factories :
raise LookupError ( f " Agent with name { id . type } not found. " )
2024-07-23 16:38:37 -07:00
# TODO: check if remote
agent_instance = await self . _get_agent ( id )
if not isinstance ( agent_instance , type ) :
2024-08-07 13:25:44 -04:00
raise TypeError ( f " Agent with name { id . type } is not of type { type . __name__ } " )
2024-07-23 16:38:37 -07:00
return agent_instance
2024-06-18 14:53:18 -04:00
# Hydrate the agent instances in a namespace. The primary reason for this is
# to ensure message type subscriptions are set up.
2024-07-23 11:49:38 -07:00
async def _process_seen_namespace ( self , namespace : str ) - > None :
2024-06-19 10:49:08 -04:00
if namespace in self . _known_namespaces :
return
self . _known_namespaces . add ( namespace )
2024-06-18 14:53:18 -04:00
for name in self . _known_agent_names :
2024-08-07 13:25:44 -04:00
await self . _get_agent ( AgentId ( type = name , key = namespace ) )