Added 'request_halt' flag, and fixed an orchestration bug (#174)

* Added 'request_halt' flag, and fixed an orchestration bug

* Fixed formatting errors.

* Fixed a hatch error with casting.
This commit is contained in:
afourney 2024-07-03 00:04:44 -07:00 committed by GitHub
parent 9df928b73e
commit 99ecb5ec7f
4 changed files with 32 additions and 12 deletions

View File

@ -35,9 +35,9 @@ if __name__ == "__main__":
The user cannot provide any feedback or perform any other action beyond executing the code you suggest. In particular, the user can't modify your code, and can't copy and paste anything, and can't fill in missing values. Thus, do not suggest incomplete code which requires users to perform any of these actions.
Check the execution result returned by the user. If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes -- code blocks must stand alone and be ready to execute without modification. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, and think of a different approach to try.
The user will run all code that you provide, and will report back the results. When receiving the results, check if the output indicates an error. Fix the error. When fixing the error, output the full code, as before, instead of partial code or code changes -- code blocks must stand alone and be ready to execute without modification. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, and think of a different approach to try.
If the code has executed successfully, and the problem is stolved, reply "TERMINATE".
If the code was executed, and the output appears to indicate that the original prolem was solved successful, reply "TERMINATE". UNDER NO OTHER CONDITIONS SHOULD "TERMINATE" BE USED.
""")
]
@ -67,12 +67,12 @@ If the code has executed successfully, and the problem is stolved, reply "TERMIN
assert isinstance(response.content, str)
self._chat_history.append(AssistantMessage(content=response.content, source=self.metadata["name"]))
if "TERMINATE" in response.content:
return
else:
await self.publish_message(
BroadcastMessage(content=UserMessage(content=response.content, source=self.metadata["name"]))
await self.publish_message(
BroadcastMessage(
content=UserMessage(content=response.content, source=self.metadata["name"]),
request_halt=("TERMINATE" in response.content),
)
)
class Executor(TypeRoutedAgent):

View File

@ -17,9 +17,11 @@ class RoundRobinOrchestrator(TypeRoutedAgent):
self,
agents: List[AgentProxy],
description: str = "Round robin orchestrator",
max_rounds: int = 20,
) -> None:
super().__init__(description)
self._agents = agents
self._max_rounds = max_rounds
self._num_rounds = 0
@message_handler
@ -34,7 +36,25 @@ class RoundRobinOrchestrator(TypeRoutedAgent):
current_timestamp = datetime.now().isoformat()
logger.info(OrchestrationEvent(current_timestamp, source, content))
if self._num_rounds > 20:
# Termination conditions
if self._num_rounds >= self._max_rounds:
logger.info(
OrchestrationEvent(
current_timestamp,
f"{self.metadata['name']} (termination condition)",
f"Max rounds ({self._max_rounds}) reached.",
)
)
return
if message.request_halt:
logger.info(
OrchestrationEvent(
current_timestamp,
f"{self.metadata['name']} (termination condition)",
f"{source} requested halt.",
)
)
return
next_agent = self._select_next_agent()
@ -45,15 +65,14 @@ class RoundRobinOrchestrator(TypeRoutedAgent):
logger.info(
OrchestrationEvent(
current_timestamp,
source="Orchestrator (thought)",
source=f"{self.metadata['name']} (thought)",
message=f"Next speaker {next_agent.metadata['name']}" "",
)
)
self._num_rounds += 1 # Call before sending the message
await self.send_message(request_reply_message, next_agent.id)
self._num_rounds += 1
def _select_next_agent(self) -> AgentProxy:
self._current_index = (self._num_rounds) % len(self._agents)
return self._agents[self._current_index]

View File

@ -6,6 +6,7 @@ from agnext.components.models import LLMMessage
@dataclass
class BroadcastMessage:
content: LLMMessage
request_halt: bool = False
@dataclass

View File

@ -524,7 +524,7 @@ echo RUN.SH COMPLETE !#!#
while True:
try:
chunk = cast(bytes, next(logs)) # Manually step the iterator so it is captures with the try-catch
chunk = next(logs) # Manually step the iterator so it is captures with the try-catch
# Stream the data to the log file and the console
chunk_str = chunk.decode("utf-8")