Update proto to include remove sub, move to rpc based operations (#5168)

* Update proto to include remove sub, move to rpc based operations

* dont add a breaking change

* mypy fix
This commit is contained in:
Jack Gerrits 2025-01-23 17:46:47 -05:00 committed by GitHub
parent c3e84dc5ca
commit 44b9bff466
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 368 additions and 32 deletions

View File

@ -48,12 +48,12 @@ message Event {
}
message RegisterAgentTypeRequest {
string request_id = 1;
string request_id = 1; // TODO: remove once message based requests are removed
string type = 2;
}
message RegisterAgentTypeResponse {
string request_id = 1;
string request_id = 1; // TODO: remove once message based requests are removed
bool success = 2;
optional string error = 3;
}
@ -69,27 +69,46 @@ message TypePrefixSubscription {
}
message Subscription {
string id = 1;
oneof subscription {
TypeSubscription typeSubscription = 1;
TypePrefixSubscription typePrefixSubscription = 2;
TypeSubscription typeSubscription = 2;
TypePrefixSubscription typePrefixSubscription = 3;
}
}
message AddSubscriptionRequest {
string request_id = 1;
string request_id = 1; // TODO: remove once message based requests are removed
Subscription subscription = 2;
}
message AddSubscriptionResponse {
string request_id = 1;
string request_id = 1; // TODO: remove once message based requests are removed
bool success = 2;
optional string error = 3;
}
message RemoveSubscriptionRequest {
string id = 1;
}
message RemoveSubscriptionResponse {
bool success = 1;
optional string error = 2;
}
message GetSubscriptionsRequest {}
message GetSubscriptionsResponse {
repeated Subscription subscriptions = 1;
}
service AgentRpc {
rpc OpenChannel (stream Message) returns (stream Message);
rpc GetState(AgentId) returns (GetStateResponse);
rpc SaveState(AgentState) returns (SaveStateResponse);
rpc RegisterAgent(RegisterAgentTypeRequest) returns (RegisterAgentTypeResponse);
rpc AddSubscription(AddSubscriptionRequest) returns (AddSubscriptionResponse);
rpc RemoveSubscription(RemoveSubscriptionRequest) returns (RemoveSubscriptionResponse);
rpc GetSubscriptions(GetSubscriptionsRequest) returns (GetSubscriptionsResponse);
}
message AgentState {

View File

@ -31,13 +31,13 @@ class TypePrefixSubscription(Subscription):
agent_type (str): Agent type to handle this subscription
"""
def __init__(self, topic_type_prefix: str, agent_type: str | AgentType):
def __init__(self, topic_type_prefix: str, agent_type: str | AgentType, id: str | None = None):
self._topic_type_prefix = topic_type_prefix
if isinstance(agent_type, AgentType):
self._agent_type = agent_type.type
else:
self._agent_type = agent_type
self._id = str(uuid.uuid4())
self._id = id or str(uuid.uuid4())
@property
def id(self) -> str:

View File

@ -30,13 +30,13 @@ class TypeSubscription(Subscription):
agent_type (str): Agent type to handle this subscription
"""
def __init__(self, topic_type: str, agent_type: str | AgentType):
def __init__(self, topic_type: str, agent_type: str | AgentType, id: str | None = None):
self._topic_type = topic_type
if isinstance(agent_type, AgentType):
self._agent_type = agent_type.type
else:
self._agent_type = agent_type
self._id = str(uuid.uuid4())
self._id = id or str(uuid.uuid4())
@property
def id(self) -> str:

View File

@ -807,25 +807,27 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
request_id = await self._get_new_request_id()
match subscription:
case TypeSubscription(topic_type=topic_type, agent_type=agent_type):
case TypeSubscription(topic_type=topic_type, agent_type=agent_type, id=id):
message = agent_worker_pb2.Message(
addSubscriptionRequest=agent_worker_pb2.AddSubscriptionRequest(
request_id=request_id,
subscription=agent_worker_pb2.Subscription(
id=id,
typeSubscription=agent_worker_pb2.TypeSubscription(
topic_type=topic_type, agent_type=agent_type
)
),
),
)
)
case TypePrefixSubscription(topic_type_prefix=topic_type_prefix, agent_type=agent_type):
case TypePrefixSubscription(topic_type_prefix=topic_type_prefix, agent_type=agent_type, id=id):
message = agent_worker_pb2.Message(
addSubscriptionRequest=agent_worker_pb2.AddSubscriptionRequest(
request_id=request_id,
subscription=agent_worker_pb2.Subscription(
id=id,
typePrefixSubscription=agent_worker_pb2.TypePrefixSubscription(
topic_type_prefix=topic_type_prefix, agent_type=agent_type
)
),
),
)
)

View File

@ -227,7 +227,9 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
add_subscription_req.subscription.typeSubscription
)
subscription = TypeSubscription(
topic_type=type_subscription_msg.topic_type, agent_type=type_subscription_msg.agent_type
topic_type=type_subscription_msg.topic_type,
agent_type=type_subscription_msg.agent_type,
id=add_subscription_req.subscription.id,
)
case "typePrefixSubscription":
@ -237,6 +239,7 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
subscription = TypePrefixSubscription(
topic_type_prefix=type_prefix_subscription_msg.topic_type_prefix,
agent_type=type_prefix_subscription_msg.agent_type,
id=add_subscription_req.subscription.id,
)
case None:
logger.warning("Received empty subscription message")
@ -260,6 +263,42 @@ class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer)
)
)
def RegisterAgent( # type: ignore
self,
request: agent_worker_pb2.RegisterAgentTypeRequest,
context: grpc.aio.ServicerContext[
agent_worker_pb2.RegisterAgentTypeRequest, agent_worker_pb2.RegisterAgentTypeResponse
],
) -> agent_worker_pb2.RegisterAgentTypeResponse:
raise NotImplementedError("Method not implemented.")
def AddSubscription( # type: ignore
self,
request: agent_worker_pb2.AddSubscriptionRequest,
context: grpc.aio.ServicerContext[
agent_worker_pb2.AddSubscriptionRequest, agent_worker_pb2.AddSubscriptionResponse
],
) -> agent_worker_pb2.AddSubscriptionResponse:
raise NotImplementedError("Method not implemented.")
def RemoveSubscription( # type: ignore
self,
request: agent_worker_pb2.RemoveSubscriptionRequest,
context: grpc.aio.ServicerContext[
agent_worker_pb2.RemoveSubscriptionRequest, agent_worker_pb2.RemoveSubscriptionResponse
],
) -> agent_worker_pb2.RemoveSubscriptionResponse:
raise NotImplementedError("Method not implemented.")
def GetSubscriptions( # type: ignore
self,
request: agent_worker_pb2.GetSubscriptionsRequest,
context: grpc.aio.ServicerContext[
agent_worker_pb2.GetSubscriptionsRequest, agent_worker_pb2.GetSubscriptionsResponse
],
) -> agent_worker_pb2.GetSubscriptionsResponse:
raise NotImplementedError("Method not implemented.")
async def GetState( # type: ignore
self,
request: agent_worker_pb2.AgentId,

File diff suppressed because one or more lines are too long

View File

@ -221,6 +221,7 @@ class RegisterAgentTypeRequest(google.protobuf.message.Message):
REQUEST_ID_FIELD_NUMBER: builtins.int
TYPE_FIELD_NUMBER: builtins.int
request_id: builtins.str
"""TODO: remove once message based requests are removed"""
type: builtins.str
def __init__(
self,
@ -240,6 +241,7 @@ class RegisterAgentTypeResponse(google.protobuf.message.Message):
SUCCESS_FIELD_NUMBER: builtins.int
ERROR_FIELD_NUMBER: builtins.int
request_id: builtins.str
"""TODO: remove once message based requests are removed"""
success: builtins.bool
error: builtins.str
def __init__(
@ -295,8 +297,10 @@ global___TypePrefixSubscription = TypePrefixSubscription
class Subscription(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ID_FIELD_NUMBER: builtins.int
TYPESUBSCRIPTION_FIELD_NUMBER: builtins.int
TYPEPREFIXSUBSCRIPTION_FIELD_NUMBER: builtins.int
id: builtins.str
@property
def typeSubscription(self) -> global___TypeSubscription: ...
@property
@ -304,11 +308,12 @@ class Subscription(google.protobuf.message.Message):
def __init__(
self,
*,
id: builtins.str = ...,
typeSubscription: global___TypeSubscription | None = ...,
typePrefixSubscription: global___TypePrefixSubscription | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["subscription", b"subscription", "typePrefixSubscription", b"typePrefixSubscription", "typeSubscription", b"typeSubscription"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["subscription", b"subscription", "typePrefixSubscription", b"typePrefixSubscription", "typeSubscription", b"typeSubscription"]) -> None: ...
def ClearField(self, field_name: typing.Literal["id", b"id", "subscription", b"subscription", "typePrefixSubscription", b"typePrefixSubscription", "typeSubscription", b"typeSubscription"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["subscription", b"subscription"]) -> typing.Literal["typeSubscription", "typePrefixSubscription"] | None: ...
global___Subscription = Subscription
@ -320,6 +325,7 @@ class AddSubscriptionRequest(google.protobuf.message.Message):
REQUEST_ID_FIELD_NUMBER: builtins.int
SUBSCRIPTION_FIELD_NUMBER: builtins.int
request_id: builtins.str
"""TODO: remove once message based requests are removed"""
@property
def subscription(self) -> global___Subscription: ...
def __init__(
@ -341,6 +347,7 @@ class AddSubscriptionResponse(google.protobuf.message.Message):
SUCCESS_FIELD_NUMBER: builtins.int
ERROR_FIELD_NUMBER: builtins.int
request_id: builtins.str
"""TODO: remove once message based requests are removed"""
success: builtins.bool
error: builtins.str
def __init__(
@ -356,6 +363,67 @@ class AddSubscriptionResponse(google.protobuf.message.Message):
global___AddSubscriptionResponse = AddSubscriptionResponse
@typing.final
class RemoveSubscriptionRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ID_FIELD_NUMBER: builtins.int
id: builtins.str
def __init__(
self,
*,
id: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["id", b"id"]) -> None: ...
global___RemoveSubscriptionRequest = RemoveSubscriptionRequest
@typing.final
class RemoveSubscriptionResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
SUCCESS_FIELD_NUMBER: builtins.int
ERROR_FIELD_NUMBER: builtins.int
success: builtins.bool
error: builtins.str
def __init__(
self,
*,
success: builtins.bool = ...,
error: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error", "success", b"success"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ...
global___RemoveSubscriptionResponse = RemoveSubscriptionResponse
@typing.final
class GetSubscriptionsRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
def __init__(
self,
) -> None: ...
global___GetSubscriptionsRequest = GetSubscriptionsRequest
@typing.final
class GetSubscriptionsResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
SUBSCRIPTIONS_FIELD_NUMBER: builtins.int
@property
def subscriptions(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Subscription]: ...
def __init__(
self,
*,
subscriptions: collections.abc.Iterable[global___Subscription] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["subscriptions", b"subscriptions"]) -> None: ...
global___GetSubscriptionsResponse = GetSubscriptionsResponse
@typing.final
class AgentState(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

View File

@ -29,6 +29,26 @@ class AgentRpcStub(object):
request_serializer=agent__worker__pb2.AgentState.SerializeToString,
response_deserializer=agent__worker__pb2.SaveStateResponse.FromString,
)
self.RegisterAgent = channel.unary_unary(
'/agents.AgentRpc/RegisterAgent',
request_serializer=agent__worker__pb2.RegisterAgentTypeRequest.SerializeToString,
response_deserializer=agent__worker__pb2.RegisterAgentTypeResponse.FromString,
)
self.AddSubscription = channel.unary_unary(
'/agents.AgentRpc/AddSubscription',
request_serializer=agent__worker__pb2.AddSubscriptionRequest.SerializeToString,
response_deserializer=agent__worker__pb2.AddSubscriptionResponse.FromString,
)
self.RemoveSubscription = channel.unary_unary(
'/agents.AgentRpc/RemoveSubscription',
request_serializer=agent__worker__pb2.RemoveSubscriptionRequest.SerializeToString,
response_deserializer=agent__worker__pb2.RemoveSubscriptionResponse.FromString,
)
self.GetSubscriptions = channel.unary_unary(
'/agents.AgentRpc/GetSubscriptions',
request_serializer=agent__worker__pb2.GetSubscriptionsRequest.SerializeToString,
response_deserializer=agent__worker__pb2.GetSubscriptionsResponse.FromString,
)
class AgentRpcServicer(object):
@ -52,6 +72,30 @@ class AgentRpcServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def RegisterAgent(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def AddSubscription(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def RemoveSubscription(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GetSubscriptions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_AgentRpcServicer_to_server(servicer, server):
rpc_method_handlers = {
@ -70,6 +114,26 @@ def add_AgentRpcServicer_to_server(servicer, server):
request_deserializer=agent__worker__pb2.AgentState.FromString,
response_serializer=agent__worker__pb2.SaveStateResponse.SerializeToString,
),
'RegisterAgent': grpc.unary_unary_rpc_method_handler(
servicer.RegisterAgent,
request_deserializer=agent__worker__pb2.RegisterAgentTypeRequest.FromString,
response_serializer=agent__worker__pb2.RegisterAgentTypeResponse.SerializeToString,
),
'AddSubscription': grpc.unary_unary_rpc_method_handler(
servicer.AddSubscription,
request_deserializer=agent__worker__pb2.AddSubscriptionRequest.FromString,
response_serializer=agent__worker__pb2.AddSubscriptionResponse.SerializeToString,
),
'RemoveSubscription': grpc.unary_unary_rpc_method_handler(
servicer.RemoveSubscription,
request_deserializer=agent__worker__pb2.RemoveSubscriptionRequest.FromString,
response_serializer=agent__worker__pb2.RemoveSubscriptionResponse.SerializeToString,
),
'GetSubscriptions': grpc.unary_unary_rpc_method_handler(
servicer.GetSubscriptions,
request_deserializer=agent__worker__pb2.GetSubscriptionsRequest.FromString,
response_serializer=agent__worker__pb2.GetSubscriptionsResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'agents.AgentRpc', rpc_method_handlers)
@ -130,3 +194,71 @@ class AgentRpc(object):
agent__worker__pb2.SaveStateResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def RegisterAgent(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/agents.AgentRpc/RegisterAgent',
agent__worker__pb2.RegisterAgentTypeRequest.SerializeToString,
agent__worker__pb2.RegisterAgentTypeResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def AddSubscription(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/agents.AgentRpc/AddSubscription',
agent__worker__pb2.AddSubscriptionRequest.SerializeToString,
agent__worker__pb2.AddSubscriptionResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def RemoveSubscription(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/agents.AgentRpc/RemoveSubscription',
agent__worker__pb2.RemoveSubscriptionRequest.SerializeToString,
agent__worker__pb2.RemoveSubscriptionResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def GetSubscriptions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/agents.AgentRpc/GetSubscriptions',
agent__worker__pb2.GetSubscriptionsRequest.SerializeToString,
agent__worker__pb2.GetSubscriptionsResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

View File

@ -34,6 +34,26 @@ class AgentRpcStub:
agent_worker_pb2.SaveStateResponse,
]
RegisterAgent: grpc.UnaryUnaryMultiCallable[
agent_worker_pb2.RegisterAgentTypeRequest,
agent_worker_pb2.RegisterAgentTypeResponse,
]
AddSubscription: grpc.UnaryUnaryMultiCallable[
agent_worker_pb2.AddSubscriptionRequest,
agent_worker_pb2.AddSubscriptionResponse,
]
RemoveSubscription: grpc.UnaryUnaryMultiCallable[
agent_worker_pb2.RemoveSubscriptionRequest,
agent_worker_pb2.RemoveSubscriptionResponse,
]
GetSubscriptions: grpc.UnaryUnaryMultiCallable[
agent_worker_pb2.GetSubscriptionsRequest,
agent_worker_pb2.GetSubscriptionsResponse,
]
class AgentRpcAsyncStub:
OpenChannel: grpc.aio.StreamStreamMultiCallable[
agent_worker_pb2.Message,
@ -50,6 +70,26 @@ class AgentRpcAsyncStub:
agent_worker_pb2.SaveStateResponse,
]
RegisterAgent: grpc.aio.UnaryUnaryMultiCallable[
agent_worker_pb2.RegisterAgentTypeRequest,
agent_worker_pb2.RegisterAgentTypeResponse,
]
AddSubscription: grpc.aio.UnaryUnaryMultiCallable[
agent_worker_pb2.AddSubscriptionRequest,
agent_worker_pb2.AddSubscriptionResponse,
]
RemoveSubscription: grpc.aio.UnaryUnaryMultiCallable[
agent_worker_pb2.RemoveSubscriptionRequest,
agent_worker_pb2.RemoveSubscriptionResponse,
]
GetSubscriptions: grpc.aio.UnaryUnaryMultiCallable[
agent_worker_pb2.GetSubscriptionsRequest,
agent_worker_pb2.GetSubscriptionsResponse,
]
class AgentRpcServicer(metaclass=abc.ABCMeta):
@abc.abstractmethod
def OpenChannel(
@ -72,4 +112,32 @@ class AgentRpcServicer(metaclass=abc.ABCMeta):
context: _ServicerContext,
) -> typing.Union[agent_worker_pb2.SaveStateResponse, collections.abc.Awaitable[agent_worker_pb2.SaveStateResponse]]: ...
@abc.abstractmethod
def RegisterAgent(
self,
request: agent_worker_pb2.RegisterAgentTypeRequest,
context: _ServicerContext,
) -> typing.Union[agent_worker_pb2.RegisterAgentTypeResponse, collections.abc.Awaitable[agent_worker_pb2.RegisterAgentTypeResponse]]: ...
@abc.abstractmethod
def AddSubscription(
self,
request: agent_worker_pb2.AddSubscriptionRequest,
context: _ServicerContext,
) -> typing.Union[agent_worker_pb2.AddSubscriptionResponse, collections.abc.Awaitable[agent_worker_pb2.AddSubscriptionResponse]]: ...
@abc.abstractmethod
def RemoveSubscription(
self,
request: agent_worker_pb2.RemoveSubscriptionRequest,
context: _ServicerContext,
) -> typing.Union[agent_worker_pb2.RemoveSubscriptionResponse, collections.abc.Awaitable[agent_worker_pb2.RemoveSubscriptionResponse]]: ...
@abc.abstractmethod
def GetSubscriptions(
self,
request: agent_worker_pb2.GetSubscriptionsRequest,
context: _ServicerContext,
) -> typing.Union[agent_worker_pb2.GetSubscriptionsResponse, collections.abc.Awaitable[agent_worker_pb2.GetSubscriptionsResponse]]: ...
def add_AgentRpcServicer_to_server(servicer: AgentRpcServicer, server: typing.Union[grpc.Server, grpc.aio.Server]) -> None: ...