2024-05-15 12:31:13 -04:00
import asyncio
2024-06-18 14:53:18 -04:00
import inspect
2024-06-04 10:17:04 -04:00
import logging
2024-06-18 14:53:18 -04:00
import threading
2024-05-15 12:31:13 -04:00
from asyncio import Future
2024-06-18 14:53:18 -04:00
from collections import defaultdict
2024-06-04 10:17:04 -04:00
from collections . abc import Sequence
2024-05-15 12:31:13 -04:00
from dataclasses import dataclass
2024-06-18 14:53:18 -04:00
from typing import Any , Awaitable , Callable , DefaultDict , Dict , List , Mapping , ParamSpec , Set , Type , TypeVar , cast
2024-05-15 12:31:13 -04:00
2024-06-18 14:53:18 -04:00
from . . core import Agent , AgentId , AgentMetadata , AgentProxy , AgentRuntime , AllNamespaces , BaseAgent , CancellationToken
2024-05-27 17:10:56 -04:00
from . . core . exceptions import MessageDroppedException
from . . core . intervention import DropMessage , InterventionHandler
2024-05-15 12:31:13 -04:00
2024-06-04 10:17:04 -04:00
# Module logger for runtime diagnostics (message routing, handler errors).
logger = logging.getLogger("agnext")
# Child logger intended for structured message events; its emit calls in this
# module are currently commented out.
event_logger = logging.getLogger("agnext.events")
2024-06-04 10:17:04 -04:00
2024-05-15 12:31:13 -04:00
2024-05-20 17:30:45 -06:00
@dataclass(kw_only=True)
class PublishMessageEnvelope:
    """A message envelope for publishing a message to every agent in a namespace
    that subscribes to the message's type."""

    message: Any
    cancellation_token: CancellationToken
    # None when the message was published without a sender.
    sender: AgentId | None
    # Namespace whose subscribers receive the message.
    namespace: str
2024-05-15 12:31:13 -04:00
2024-05-20 17:30:45 -06:00
@dataclass(kw_only=True)
class SendMessageEnvelope:
    """A message envelope for sending a message directly to one specific agent."""

    message: Any
    # None when the message was sent without a sender.
    sender: AgentId | None
    recipient: AgentId
    # Completed with the recipient's response, or with the error raised while delivering.
    future: Future[Any]
    cancellation_token: CancellationToken
2024-05-15 12:31:13 -04:00
2024-05-20 17:30:45 -06:00
@dataclass(kw_only=True)
class ResponseMessageEnvelope:
    """Envelope carrying an agent's response back to whoever sent the request.

    Processing it completes ``future`` with ``message``.
    """

    # The response payload.
    message: Any
    # The future that was handed out when the original request was sent.
    future: Future[Any]
    # The agent that produced the response.
    sender: AgentId
    # The agent that sent the original request; None if it had no sender.
    recipient: AgentId | None
# NOTE(review): P does not appear to be used in this module — confirm before removing.
P = ParamSpec("P")
# The concrete agent type an agent factory returns; must be an Agent.
T = TypeVar("T", bound=Agent)
class Counter:
    """An integer counter whose updates are serialized by a ``threading.Lock``."""

    def __init__(self) -> None:
        self._count: int = 0
        self.threadLock = threading.Lock()

    def increment(self) -> None:
        """Add one to the count."""
        # ``with`` releases the lock even if the body raises, unlike a manual
        # acquire()/release() pair.
        with self.threadLock:
            self._count += 1

    def get(self) -> int:
        """Return the current count."""
        with self.threadLock:
            return self._count

    def decrement(self) -> None:
        """Subtract one from the count."""
        with self.threadLock:
            self._count -= 1
2024-05-19 17:12:49 -06:00
2024-05-23 16:00:05 -04:00
class SingleThreadedAgentRuntime ( AgentRuntime ) :
def __init__ ( self , * , before_send : InterventionHandler | None = None ) - > None :
2024-05-26 08:45:02 -04:00
self . _message_queue : List [ PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope ] = [ ]
2024-06-18 14:53:18 -04:00
# (namespace, type) -> List[AgentId]
self . _per_type_subscribers : DefaultDict [ tuple [ str , type ] , Set [ AgentId ] ] = defaultdict ( set )
self . _agent_factories : Dict [ str , Callable [ [ ] , Agent ] | Callable [ [ AgentRuntime , AgentId ] , Agent ] ] = { }
# If empty, then all namespaces are valid for that agent type
self . _valid_namespaces : Dict [ str , Sequence [ str ] ] = { }
self . _instantiated_agents : Dict [ AgentId , Agent ] = { }
2024-05-20 17:30:45 -06:00
self . _before_send = before_send
2024-06-18 14:53:18 -04:00
self . _known_namespaces : set [ str ] = set ( )
self . _outstanding_tasks = Counter ( )
2024-06-04 10:17:04 -04:00
@property
2024-06-07 13:33:51 -07:00
def unprocessed_messages (
self ,
) - > Sequence [ PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope ] :
2024-06-04 10:17:04 -04:00
return self . _message_queue
2024-06-18 14:53:18 -04:00
@property
def outstanding_tasks ( self ) - > int :
return self . _outstanding_tasks . get ( )
@property
def _known_agent_names ( self ) - > Set [ str ] :
return set ( self . _agent_factories . keys ( ) )
2024-05-15 12:31:13 -04:00
# Returns the response of the message
2024-05-20 13:32:08 -06:00
def send_message (
2024-05-20 17:30:45 -06:00
self ,
2024-05-23 16:00:05 -04:00
message : Any ,
2024-06-18 14:53:18 -04:00
recipient : AgentId ,
2024-05-20 17:30:45 -06:00
* ,
2024-06-18 14:53:18 -04:00
sender : AgentId | None = None ,
2024-05-20 17:30:45 -06:00
cancellation_token : CancellationToken | None = None ,
2024-05-23 16:00:05 -04:00
) - > Future [ Any | None ] :
2024-05-20 13:32:08 -06:00
if cancellation_token is None :
cancellation_token = CancellationToken ( )
2024-06-07 13:33:51 -07:00
# event_logger.info(
# MessageEvent(
# payload=message,
# sender=sender,
# receiver=recipient,
# kind=MessageKind.DIRECT,
# delivery_stage=DeliveryStage.SEND,
# )
# )
2024-06-18 14:53:18 -04:00
if recipient . namespace not in self . _known_namespaces :
self . _prepare_namespace ( recipient . namespace )
2024-06-17 15:37:46 -04:00
2024-05-23 16:00:05 -04:00
future = asyncio . get_event_loop ( ) . create_future ( )
2024-06-18 14:53:18 -04:00
if recipient . name not in self . _known_agent_names :
2024-05-23 16:00:05 -04:00
future . set_exception ( Exception ( " Recipient not found " ) )
2024-05-15 12:31:13 -04:00
2024-06-18 14:53:18 -04:00
if sender is not None and sender . namespace != recipient . namespace :
raise ValueError ( " Sender and recipient must be in the same namespace to communicate. " )
logger . info ( f " Sending message of type { type ( message ) . __name__ } to { recipient . name } : { message . __dict__ } " )
2024-05-20 17:30:45 -06:00
self . _message_queue . append (
SendMessageEnvelope (
message = message ,
recipient = recipient ,
future = future ,
cancellation_token = cancellation_token ,
sender = sender ,
)
)
2024-05-20 13:32:08 -06:00
2024-05-15 12:31:13 -04:00
return future
2024-05-26 08:45:02 -04:00
def publish_message (
2024-05-23 16:00:05 -04:00
self ,
message : Any ,
* ,
2024-06-18 14:53:18 -04:00
namespace : str | None = None ,
sender : AgentId | None = None ,
2024-05-23 16:00:05 -04:00
cancellation_token : CancellationToken | None = None ,
2024-05-26 08:45:02 -04:00
) - > Future [ None ] :
2024-05-20 13:32:08 -06:00
if cancellation_token is None :
cancellation_token = CancellationToken ( )
2024-06-07 13:33:51 -07:00
logger . info ( f " Publishing message of type { type ( message ) . __name__ } to all subscribers: { message . __dict__ } " )
# event_logger.info(
# MessageEvent(
# payload=message,
# sender=sender,
# receiver=None,
# kind=MessageKind.PUBLISH,
# delivery_stage=DeliveryStage.SEND,
# )
# )
2024-06-18 14:53:18 -04:00
if sender is None and namespace is None :
raise ValueError ( " Namespace must be provided if sender is not provided. " )
sender_namespace = sender . namespace if sender is not None else None
explicit_namespace = namespace
if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace :
raise ValueError (
f " Explicit namespace { explicit_namespace } does not match sender namespace { sender_namespace } "
)
assert explicit_namespace is not None or sender_namespace is not None
namespace = cast ( str , explicit_namespace or sender_namespace )
if namespace not in self . _known_namespaces :
self . _prepare_namespace ( namespace )
2024-05-20 17:30:45 -06:00
self . _message_queue . append (
2024-05-26 08:45:02 -04:00
PublishMessageEnvelope (
2024-05-23 16:00:05 -04:00
message = message ,
cancellation_token = cancellation_token ,
sender = sender ,
2024-06-18 14:53:18 -04:00
namespace = namespace ,
2024-05-20 17:30:45 -06:00
)
)
2024-05-23 16:00:05 -04:00
2024-05-26 08:45:02 -04:00
future = asyncio . get_event_loop ( ) . create_future ( )
future . set_result ( None )
2024-05-15 12:31:13 -04:00
return future
2024-05-27 20:25:25 -04:00
def save_state ( self ) - > Mapping [ str , Any ] :
state : Dict [ str , Dict [ str , Any ] ] = { }
2024-06-18 14:53:18 -04:00
for agent_id in self . _instantiated_agents :
state [ str ( agent_id ) ] = dict ( self . _get_agent ( agent_id ) . save_state ( ) )
2024-05-27 20:25:25 -04:00
return state
def load_state ( self , state : Mapping [ str , Any ] ) - > None :
2024-06-18 14:53:18 -04:00
for agent_id_str in state :
agent_id = AgentId . from_str ( agent_id_str )
if agent_id . name in self . _known_agent_names :
self . _get_agent ( agent_id ) . load_state ( state [ str ( agent_id ) ] )
2024-05-27 20:25:25 -04:00
2024-05-23 16:00:05 -04:00
async def _process_send ( self , message_envelope : SendMessageEnvelope ) - > None :
2024-05-20 17:30:45 -06:00
recipient = message_envelope . recipient
2024-06-18 14:53:18 -04:00
# todo: check if recipient is in the known namespaces
# assert recipient in self._agents
2024-05-15 12:31:13 -04:00
2024-05-20 13:32:08 -06:00
try :
2024-06-18 14:53:18 -04:00
sender_name = message_envelope . sender . name if message_envelope . sender is not None else " Unknown "
2024-06-04 10:17:04 -04:00
logger . info (
2024-06-18 14:53:18 -04:00
f " Calling message handler for { recipient } with message type { type ( message_envelope . message ) . __name__ } sent by { sender_name } "
2024-06-04 10:17:04 -04:00
)
2024-06-07 13:33:51 -07:00
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=recipient,
# kind=MessageKind.DIRECT,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
2024-06-18 14:53:18 -04:00
recipient_agent = self . _get_agent ( recipient )
response = await recipient_agent . on_message (
2024-05-23 16:00:05 -04:00
message_envelope . message ,
cancellation_token = message_envelope . cancellation_token ,
2024-05-20 13:32:08 -06:00
)
except BaseException as e :
message_envelope . future . set_exception ( e )
return
2024-05-26 08:45:02 -04:00
self . _message_queue . append (
ResponseMessageEnvelope (
message = response ,
future = message_envelope . future ,
sender = message_envelope . recipient ,
recipient = message_envelope . sender ,
2024-05-23 16:00:05 -04:00
)
2024-05-26 08:45:02 -04:00
)
2024-06-18 14:53:18 -04:00
self . _outstanding_tasks . decrement ( )
2024-05-23 16:00:05 -04:00
2024-05-26 08:45:02 -04:00
async def _process_publish ( self , message_envelope : PublishMessageEnvelope ) - > None :
2024-05-23 16:00:05 -04:00
responses : List [ Awaitable [ Any ] ] = [ ]
2024-06-18 14:53:18 -04:00
target_namespace = message_envelope . namespace
for agent_id in self . _per_type_subscribers [ ( target_namespace , type ( message_envelope . message ) ) ] :
if message_envelope . sender is not None and agent_id . name == message_envelope . sender . name :
2024-05-28 16:21:40 -04:00
continue
2024-06-04 10:17:04 -04:00
2024-06-18 14:53:18 -04:00
sender_agent = self . _get_agent ( message_envelope . sender ) if message_envelope . sender is not None else None
sender_name = sender_agent . metadata [ " name " ] if sender_agent is not None else " Unknown "
2024-06-04 10:17:04 -04:00
logger . info (
2024-06-18 14:53:18 -04:00
f " Calling message handler for { agent_id . name } with message type { type ( message_envelope . message ) . __name__ } published by { sender_name } "
2024-06-04 10:17:04 -04:00
)
2024-06-07 13:33:51 -07:00
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=agent,
# kind=MessageKind.PUBLISH,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
2024-06-04 10:17:04 -04:00
2024-06-18 14:53:18 -04:00
agent = self . _get_agent ( agent_id )
2024-05-23 16:00:05 -04:00
future = agent . on_message (
message_envelope . message ,
cancellation_token = message_envelope . cancellation_token ,
)
2024-05-19 17:12:49 -06:00
responses . append ( future )
2024-05-20 13:32:08 -06:00
try :
2024-05-26 08:45:02 -04:00
_all_responses = await asyncio . gather ( * responses )
except BaseException :
2024-06-04 10:17:04 -04:00
logger . error ( " Error processing publish message " , exc_info = True )
2024-05-20 13:32:08 -06:00
return
2024-06-18 14:53:18 -04:00
self . _outstanding_tasks . decrement ( )
2024-05-26 08:45:02 -04:00
# TODO if responses are given for a publish
2024-05-19 17:12:49 -06:00
2024-05-23 16:00:05 -04:00
async def _process_response ( self , message_envelope : ResponseMessageEnvelope ) - > None :
2024-06-07 13:33:51 -07:00
content = (
message_envelope . message . __dict__
if hasattr ( message_envelope . message , " __dict__ " )
else message_envelope . message
)
2024-06-04 10:17:04 -04:00
logger . info (
2024-06-18 14:53:18 -04:00
f " Resolving response with message type { type ( message_envelope . message ) . __name__ } for recipient { message_envelope . recipient } from { message_envelope . sender . name } : { content } "
2024-06-04 10:17:04 -04:00
)
2024-06-07 13:33:51 -07:00
# event_logger.info(
# MessageEvent(
# payload=message_envelope.message,
# sender=message_envelope.sender,
# receiver=message_envelope.recipient,
# kind=MessageKind.RESPOND,
# delivery_stage=DeliveryStage.DELIVER,
# )
# )
2024-06-18 14:53:18 -04:00
self . _outstanding_tasks . decrement ( )
2024-05-19 17:12:49 -06:00
message_envelope . future . set_result ( message_envelope . message )
2024-05-15 12:31:13 -04:00
async def process_next ( self ) - > None :
2024-05-17 14:59:00 -07:00
if len ( self . _message_queue ) == 0 :
2024-05-15 12:31:13 -04:00
# Yield control to the event loop to allow other tasks to run
await asyncio . sleep ( 0 )
return
2024-05-17 14:59:00 -07:00
message_envelope = self . _message_queue . pop ( 0 )
2024-05-15 12:31:13 -04:00
2024-05-17 14:59:00 -07:00
match message_envelope :
2024-05-20 17:30:45 -06:00
case SendMessageEnvelope ( message = message , sender = sender , recipient = recipient , future = future ) :
if self . _before_send is not None :
2024-06-17 17:34:56 -07:00
try :
temp_message = await self . _before_send . on_send ( message , sender = sender , recipient = recipient )
except BaseException as e :
future . set_exception ( e )
return
2024-05-20 17:30:45 -06:00
if temp_message is DropMessage or isinstance ( temp_message , DropMessage ) :
future . set_exception ( MessageDroppedException ( ) )
return
2024-05-23 16:00:05 -04:00
message_envelope . message = temp_message
2024-06-18 14:53:18 -04:00
self . _outstanding_tasks . increment ( )
2024-05-20 17:30:45 -06:00
asyncio . create_task ( self . _process_send ( message_envelope ) )
2024-05-26 08:45:02 -04:00
case PublishMessageEnvelope (
2024-05-20 17:30:45 -06:00
message = message ,
sender = sender ,
) :
if self . _before_send is not None :
2024-06-17 17:34:56 -07:00
try :
temp_message = await self . _before_send . on_publish ( message , sender = sender )
except BaseException as e :
# TODO: we should raise the intervention exception to the publisher.
logger . error ( f " Exception raised in in intervention handler: { e } " , exc_info = True )
return
2024-05-20 17:30:45 -06:00
if temp_message is DropMessage or isinstance ( temp_message , DropMessage ) :
2024-05-26 08:45:02 -04:00
# TODO log message dropped
2024-05-20 17:30:45 -06:00
return
2024-05-23 16:00:05 -04:00
message_envelope . message = temp_message
2024-06-18 14:53:18 -04:00
self . _outstanding_tasks . increment ( )
2024-05-26 08:45:02 -04:00
asyncio . create_task ( self . _process_publish ( message_envelope ) )
2024-05-20 17:30:45 -06:00
case ResponseMessageEnvelope ( message = message , sender = sender , recipient = recipient , future = future ) :
if self . _before_send is not None :
2024-06-17 17:34:56 -07:00
try :
temp_message = await self . _before_send . on_response ( message , sender = sender , recipient = recipient )
except BaseException as e :
# TODO: should we raise the exception to sender of the response instead?
future . set_exception ( e )
return
2024-05-20 17:30:45 -06:00
if temp_message is DropMessage or isinstance ( temp_message , DropMessage ) :
future . set_exception ( MessageDroppedException ( ) )
return
2024-05-23 16:00:05 -04:00
message_envelope . message = temp_message
2024-06-18 14:53:18 -04:00
self . _outstanding_tasks . increment ( )
2024-05-20 17:30:45 -06:00
asyncio . create_task ( self . _process_response ( message_envelope ) )
2024-05-17 14:59:00 -07:00
# Yield control to the message loop to allow other tasks to run
2024-05-15 12:31:13 -04:00
await asyncio . sleep ( 0 )
2024-06-17 10:44:46 -04:00
2024-06-18 14:53:18 -04:00
def agent_metadata ( self , agent : AgentId ) - > AgentMetadata :
2024-06-17 15:37:46 -04:00
return self . _get_agent ( agent ) . metadata
2024-06-18 14:53:18 -04:00
def agent_save_state ( self , agent : AgentId ) - > Mapping [ str , Any ] :
2024-06-17 15:37:46 -04:00
return self . _get_agent ( agent ) . save_state ( )
2024-06-17 12:43:51 -04:00
2024-06-18 14:53:18 -04:00
def agent_load_state ( self , agent : AgentId , state : Mapping [ str , Any ] ) - > None :
2024-06-17 15:37:46 -04:00
self . _get_agent ( agent ) . load_state ( state )
2024-06-18 14:53:18 -04:00
def register (
self ,
name : str ,
agent_factory : Callable [ [ ] , T ] | Callable [ [ AgentRuntime , AgentId ] , T ] ,
* ,
valid_namespaces : Sequence [ str ] | Type [ AllNamespaces ] = AllNamespaces ,
) - > None :
if name in self . _agent_factories :
raise ValueError ( f " Agent with name { name } already exists. " )
self . _agent_factories [ name ] = agent_factory
if valid_namespaces is not AllNamespaces :
self . _valid_namespaces [ name ] = cast ( Sequence [ str ] , valid_namespaces )
else :
self . _valid_namespaces [ name ] = [ ]
def _invoke_agent_factory (
self , agent_factory : Callable [ [ ] , T ] | Callable [ [ AgentRuntime , AgentId ] , T ] , agent_id : AgentId
) - > T :
if len ( inspect . signature ( agent_factory ) . parameters ) == 0 :
factory_one = cast ( Callable [ [ ] , T ] , agent_factory )
agent = factory_one ( )
elif len ( inspect . signature ( agent_factory ) . parameters ) == 2 :
factory_two = cast ( Callable [ [ AgentRuntime , AgentId ] , T ] , agent_factory )
agent = factory_two ( self , agent_id )
else :
raise ValueError ( " Agent factory must take 0 or 2 arguments. " )
# TODO: should this be part of the base agent interface?
if isinstance ( agent , BaseAgent ) :
agent . bind_id ( agent_id )
agent . bind_runtime ( self )
return agent
def _type_valid_for_namespace ( self , agent_id : AgentId ) - > bool :
if agent_id . name not in self . _agent_factories :
raise KeyError ( f " Agent with name { agent_id . name } not found. " )
valid_namespaces = self . _valid_namespaces [ agent_id . name ]
if len ( valid_namespaces ) == 0 :
return True
return agent_id . namespace in valid_namespaces
def _get_agent ( self , agent_id : AgentId ) - > Agent :
if agent_id in self . _instantiated_agents :
return self . _instantiated_agents [ agent_id ]
if not self . _type_valid_for_namespace ( agent_id ) :
raise ValueError ( f " Agent with name { agent_id . name } not valid for namespace { agent_id . namespace } . " )
self . _known_namespaces . add ( agent_id . namespace )
if agent_id . name not in self . _agent_factories :
raise ValueError ( f " Agent with name { agent_id . name } not found. " )
agent_factory = self . _agent_factories [ agent_id . name ]
agent = self . _invoke_agent_factory ( agent_factory , agent_id )
for message_type in agent . metadata [ " subscriptions " ] :
self . _per_type_subscribers [ ( agent_id . namespace , message_type ) ] . add ( agent_id )
self . _instantiated_agents [ agent_id ] = agent
return agent
def get ( self , name : str , * , namespace : str = " default " ) - > AgentId :
return self . _get_agent ( AgentId ( name = name , namespace = namespace ) ) . id
def get_proxy ( self , name : str , * , namespace : str = " default " ) - > AgentProxy :
id = self . get ( name , namespace = namespace )
return AgentProxy ( id , self )
# Hydrate the agent instances in a namespace. The primary reason for this is
# to ensure message type subscriptions are set up.
def _prepare_namespace ( self , namespace : str ) - > None :
for name in self . _known_agent_names :
if self . _type_valid_for_namespace ( AgentId ( name = name , namespace = namespace ) ) :
self . _get_agent ( AgentId ( name = name , namespace = namespace ) )