TeamOne cancellation token support and making logger a member variable (#622)

* Added cancellation token support for team_one; made logger a member variable of each agent.

formatting

fix error

fix error

formatting

* No need to create a new cancellation token
This commit is contained in:
Enhao Zhang 2024-09-24 16:54:22 -07:00 committed by GitHub
parent 6dcbf869ad
commit afdfb4ea58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 100 additions and 57 deletions

View File

@ -15,8 +15,6 @@ from team_one.messages import (
TeamOneMessages, TeamOneMessages,
) )
logger = logging.getLogger(EVENT_LOGGER_NAME + ".agent")
class TeamOneBaseAgent(RoutedAgent): class TeamOneBaseAgent(RoutedAgent):
"""An agent that optionally ensures messages are handled non-concurrently in the order they arrive.""" """An agent that optionally ensures messages are handled non-concurrently in the order they arrive."""
@ -29,6 +27,7 @@ class TeamOneBaseAgent(RoutedAgent):
super().__init__(description) super().__init__(description)
self._handle_messages_concurrently = handle_messages_concurrently self._handle_messages_concurrently = handle_messages_concurrently
self._enabled = True self._enabled = True
self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.id.key}.agent")
if not self._handle_messages_concurrently: if not self._handle_messages_concurrently:
# TODO: make it possible to stop # TODO: make it possible to stop
@ -40,6 +39,7 @@ class TeamOneBaseAgent(RoutedAgent):
message, ctx, future = await self._message_queue.get() message, ctx, future = await self._message_queue.get()
if ctx.cancellation_token.is_cancelled(): if ctx.cancellation_token.is_cancelled():
# TODO: Do we need to resolve the future here? # TODO: Do we need to resolve the future here?
future.cancel()
continue continue
try: try:
@ -54,6 +54,8 @@ class TeamOneBaseAgent(RoutedAgent):
else: else:
raise ValueError("Unknown message type.") raise ValueError("Unknown message type.")
future.set_result(None) future.set_result(None)
except asyncio.CancelledError:
future.cancel()
except Exception as e: except Exception as e:
future.set_exception(e) future.set_exception(e)
@ -92,9 +94,19 @@ class TeamOneBaseAgent(RoutedAgent):
async def _handle_deactivate(self, message: DeactivateMessage, ctx: MessageContext) -> None: async def _handle_deactivate(self, message: DeactivateMessage, ctx: MessageContext) -> None:
"""Handle a deactivate message.""" """Handle a deactivate message."""
self._enabled = False self._enabled = False
logger.info( self.logger.info(
AgentEvent( AgentEvent(
f"{self.metadata['type']} (deactivated)", f"{self.metadata['type']} (deactivated)",
"", "",
) )
) )
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
"""Drop the message, with a log."""
# self.logger.info(
# AgentEvent(
# f"{self.metadata['type']} (unhandled message)",
# f"Unhandled message type: {type(message)}",
# )
# )
pass

View File

@ -10,8 +10,6 @@ from ..messages import BroadcastMessage, OrchestrationEvent, RequestReplyMessage
from ..utils import message_content_to_str from ..utils import message_content_to_str
from .base_agent import TeamOneBaseAgent from .base_agent import TeamOneBaseAgent
logger = logging.getLogger(EVENT_LOGGER_NAME + ".orchestrator")
class BaseOrchestrator(TeamOneBaseAgent): class BaseOrchestrator(TeamOneBaseAgent):
def __init__( def __init__(
@ -28,6 +26,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
self._max_time = max_time self._max_time = max_time
self._num_rounds = 0 self._num_rounds = 0
self._start_time: float = -1.0 self._start_time: float = -1.0
self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.id.key}.orchestrator")
async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext) -> None: async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext) -> None:
"""Handle an incoming message.""" """Handle an incoming message."""
@ -42,11 +41,11 @@ class BaseOrchestrator(TeamOneBaseAgent):
content = message_content_to_str(message.content.content) content = message_content_to_str(message.content.content)
logger.info(OrchestrationEvent(source, content)) self.logger.info(OrchestrationEvent(source, content))
# Termination conditions # Termination conditions
if self._num_rounds >= self._max_rounds: if self._num_rounds >= self._max_rounds:
logger.info( self.logger.info(
OrchestrationEvent( OrchestrationEvent(
f"{self.metadata['type']} (termination condition)", f"{self.metadata['type']} (termination condition)",
f"Max rounds ({self._max_rounds}) reached.", f"Max rounds ({self._max_rounds}) reached.",
@ -55,7 +54,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
return return
if time.time() - self._start_time >= self._max_time: if time.time() - self._start_time >= self._max_time:
logger.info( self.logger.info(
OrchestrationEvent( OrchestrationEvent(
f"{self.metadata['type']} (termination condition)", f"{self.metadata['type']} (termination condition)",
f"Max time ({self._max_time}s) reached.", f"Max time ({self._max_time}s) reached.",
@ -64,7 +63,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
return return
if message.request_halt: if message.request_halt:
logger.info( self.logger.info(
OrchestrationEvent( OrchestrationEvent(
f"{self.metadata['type']} (termination condition)", f"{self.metadata['type']} (termination condition)",
f"{source} requested halt.", f"{source} requested halt.",
@ -74,7 +73,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
next_agent = await self._select_next_agent(message.content) next_agent = await self._select_next_agent(message.content)
if next_agent is None: if next_agent is None:
logger.info( self.logger.info(
OrchestrationEvent( OrchestrationEvent(
f"{self.metadata['type']} (termination condition)", f"{self.metadata['type']} (termination condition)",
"No agent selected.", "No agent selected.",
@ -84,7 +83,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
request_reply_message = RequestReplyMessage() request_reply_message = RequestReplyMessage()
# emit an event # emit an event
logger.info( self.logger.info(
OrchestrationEvent( OrchestrationEvent(
source=f"{self.metadata['type']} (thought)", source=f"{self.metadata['type']} (thought)",
message=f"Next speaker {(await next_agent.metadata)['type']}" "", message=f"Next speaker {(await next_agent.metadata)['type']}" "",
@ -92,7 +91,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
) )
self._num_rounds += 1 # Call before sending the message self._num_rounds += 1 # Call before sending the message
await self.send_message(request_reply_message, next_agent.id) await self.send_message(request_reply_message, next_agent.id, cancellation_token=ctx.cancellation_token)
async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]: async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
raise NotImplementedError() raise NotImplementedError()

View File

@ -46,7 +46,11 @@ class BaseWorker(TeamOneBaseAgent):
user_message = UserMessage(content=response, source=self.metadata["type"]) user_message = UserMessage(content=response, source=self.metadata["type"])
topic_id = TopicId("default", self.id.key) topic_id = TopicId("default", self.id.key)
await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt), topic_id=topic_id) await self.publish_message(
BroadcastMessage(content=user_message, request_halt=request_halt),
topic_id=topic_id,
cancellation_token=ctx.cancellation_token,
)
async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]: async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]:
"""Returns (request_halt, response_message)""" """Returns (request_halt, response_message)"""

View File

@ -49,7 +49,9 @@ Reply "TERMINATE" in the end when everything is done.""")
"""Respond to a reply request.""" """Respond to a reply request."""
# Make an inference to the model. # Make an inference to the model.
response = await self._model_client.create(self._system_messages + self._chat_history) response = await self._model_client.create(
self._system_messages + self._chat_history, cancellation_token=cancellation_token
)
assert isinstance(response.content, str) assert isinstance(response.content, str)
return "TERMINATE" in response.content, response.content return "TERMINATE" in response.content, response.content

View File

@ -90,7 +90,7 @@ class FileSurfer(BaseWorker):
) )
create_result = await self._model_client.create( create_result = await self._model_client.create(
messages=history + [context_message, task_message], tools=self._tools messages=history + [context_message, task_message], tools=self._tools, cancellation_token=cancellation_token
) )
response = create_result.content response = create_result.content

View File

@ -7,7 +7,7 @@ import os
import pathlib import pathlib
import re import re
import traceback import traceback
from typing import Any, BinaryIO, Dict, List, Tuple, Union, cast # Any, Callable, Dict, List, Literal, Tuple from typing import Any, BinaryIO, Dict, List, Tuple, Union, cast, Optional # Any, Callable, Dict, List, Literal, Tuple
from urllib.parse import quote_plus # parse_qs, quote, unquote, urlparse, urlunparse from urllib.parse import quote_plus # parse_qs, quote, unquote, urlparse, urlunparse
import aiofiles import aiofiles
@ -67,8 +67,6 @@ MLM_WIDTH = 1224
SCREENSHOT_TOKENS = 1105 SCREENSHOT_TOKENS = 1105
logger = logging.getLogger(EVENT_LOGGER_NAME + ".MultimodalWebSurfer")
# Sentinels # Sentinels
class DEFAULT_CHANNEL(metaclass=SentinelMeta): class DEFAULT_CHANNEL(metaclass=SentinelMeta):
@ -96,6 +94,7 @@ class MultimodalWebSurfer(BaseWorker):
self._page: Page | None = None self._page: Page | None = None
self._last_download: Download | None = None self._last_download: Download | None = None
self._prior_metadata_hash: str | None = None self._prior_metadata_hash: str | None = None
self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.id.key}.MultimodalWebSurfer")
# Read page_script # Read page_script
self._page_script: str = "" self._page_script: str = ""
@ -196,7 +195,7 @@ setInterval(function() {{
""".strip(), """.strip(),
) )
await self._page.screenshot(path=os.path.join(self.debug_dir, "screenshot.png")) await self._page.screenshot(path=os.path.join(self.debug_dir, "screenshot.png"))
logger.info(f"Multimodal Web Surfer debug screens: {pathlib.Path(os.path.abspath(debug_html)).as_uri()}\n") self.logger.info(f"Multimodal Web Surfer debug screens: {pathlib.Path(os.path.abspath(debug_html)).as_uri()}\n")
async def _reset(self, cancellation_token: CancellationToken) -> None: async def _reset(self, cancellation_token: CancellationToken) -> None:
assert self._page is not None assert self._page is not None
@ -205,7 +204,7 @@ setInterval(function() {{
await self._visit_page(self.start_page) await self._visit_page(self.start_page)
if self.debug_dir: if self.debug_dir:
await self._page.screenshot(path=os.path.join(self.debug_dir, "screenshot.png")) await self._page.screenshot(path=os.path.join(self.debug_dir, "screenshot.png"))
logger.info( self.logger.info(
WebSurferEvent( WebSurferEvent(
source=self.metadata["type"], source=self.metadata["type"],
url=self._page.url, url=self._page.url,
@ -250,13 +249,18 @@ setInterval(function() {{
return False, f"Web surfing error:\n\n{traceback.format_exc()}" return False, f"Web surfing error:\n\n{traceback.format_exc()}"
async def _execute_tool( async def _execute_tool(
self, message: List[FunctionCall], rects: Dict[str, InteractiveRegion], tool_names: str, use_ocr: bool = True self,
message: List[FunctionCall],
rects: Dict[str, InteractiveRegion],
tool_names: str,
use_ocr: bool = True,
cancellation_token: Optional[CancellationToken] = None,
) -> Tuple[bool, UserContent]: ) -> Tuple[bool, UserContent]:
name = message[0].name name = message[0].name
args = json.loads(message[0].arguments) args = json.loads(message[0].arguments)
action_description = "" action_description = ""
assert self._page is not None assert self._page is not None
logger.info( self.logger.info(
WebSurferEvent( WebSurferEvent(
source=self.metadata["type"], source=self.metadata["type"],
url=self._page.url, url=self._page.url,
@ -340,11 +344,11 @@ setInterval(function() {{
elif name == "answer_question": elif name == "answer_question":
question = str(args.get("question")) question = str(args.get("question"))
# Do Q&A on the DOM. No need to take further action. Browser state does not change. # Do Q&A on the DOM. No need to take further action. Browser state does not change.
return False, await self._summarize_page(question=question) return False, await self._summarize_page(question=question, cancellation_token=cancellation_token)
elif name == "summarize_page": elif name == "summarize_page":
# Summarize the DOM. No need to take further action. Browser state does not change. # Summarize the DOM. No need to take further action. Browser state does not change.
return False, await self._summarize_page() return False, await self._summarize_page(cancellation_token=cancellation_token)
elif name == "sleep": elif name == "sleep":
action_description = "I am waiting a short period of time before taking further action." action_description = "I am waiting a short period of time before taking further action."
@ -394,7 +398,9 @@ setInterval(function() {{
async with aiofiles.open(os.path.join(self.debug_dir, "screenshot.png"), "wb") as file: async with aiofiles.open(os.path.join(self.debug_dir, "screenshot.png"), "wb") as file:
await file.write(new_screenshot) await file.write(new_screenshot)
ocr_text = await self._get_ocr_text(new_screenshot) if use_ocr is True else "" ocr_text = (
await self._get_ocr_text(new_screenshot, cancellation_token=cancellation_token) if use_ocr is True else ""
)
# Return the complete observation # Return the complete observation
message_content = "" # message.content or "" message_content = "" # message.content or ""
@ -518,7 +524,7 @@ When deciding between tools, consider if the request can be best addressed by:
UserMessage(content=[text_prompt, AGImage.from_pil(scaled_screenshot)], source=self.metadata["type"]) UserMessage(content=[text_prompt, AGImage.from_pil(scaled_screenshot)], source=self.metadata["type"])
) )
response = await self._model_client.create( response = await self._model_client.create(
history, tools=tools, extra_create_args={"tool_choice": "auto"} history, tools=tools, extra_create_args={"tool_choice": "auto"}, cancellation_token=cancellation_token
) # , "parallel_tool_calls": False}) ) # , "parallel_tool_calls": False})
message = response.content message = response.content
@ -529,7 +535,7 @@ When deciding between tools, consider if the request can be best addressed by:
return False, message return False, message
elif isinstance(message, list): elif isinstance(message, list):
# Take an action # Take an action
return await self._execute_tool(message, rects, tool_names) return await self._execute_tool(message, rects, tool_names, cancellation_token=cancellation_token)
else: else:
# Not sure what happened here # Not sure what happened here
raise AssertionError(f"Unknown response format '{message}'") raise AssertionError(f"Unknown response format '{message}'")
@ -668,7 +674,7 @@ When deciding between tools, consider if the request can be best addressed by:
assert isinstance(new_page, Page) assert isinstance(new_page, Page)
await self._on_new_page(new_page) await self._on_new_page(new_page)
logger.info( self.logger.info(
WebSurferEvent( WebSurferEvent(
source=self.metadata["type"], source=self.metadata["type"],
url=self._page.url, url=self._page.url,
@ -716,7 +722,12 @@ When deciding between tools, consider if the request can be best addressed by:
""" """
) )
async def _summarize_page(self, question: str | None = None, token_limit: int = 100000) -> str: async def _summarize_page(
self,
question: str | None = None,
token_limit: int = 100000,
cancellation_token: Optional[CancellationToken] = None,
) -> str:
assert self._page is not None assert self._page is not None
page_markdown: str = await self._get_page_markdown() page_markdown: str = await self._get_page_markdown()
@ -780,12 +791,14 @@ When deciding between tools, consider if the request can be best addressed by:
) )
# Generate the response # Generate the response
response = await self._model_client.create(messages) response = await self._model_client.create(messages, cancellation_token=cancellation_token)
scaled_screenshot.close() scaled_screenshot.close()
assert isinstance(response.content, str) assert isinstance(response.content, str)
return response.content return response.content
async def _get_ocr_text(self, image: bytes | io.BufferedIOBase | Image.Image) -> str: async def _get_ocr_text(
self, image: bytes | io.BufferedIOBase | Image.Image, cancellation_token: Optional[CancellationToken] = None
) -> str:
scaled_screenshot = None scaled_screenshot = None
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
scaled_screenshot = image.resize((MLM_WIDTH, MLM_HEIGHT)) scaled_screenshot = image.resize((MLM_WIDTH, MLM_HEIGHT))
@ -810,7 +823,7 @@ When deciding between tools, consider if the request can be best addressed by:
source=self.metadata["type"], source=self.metadata["type"],
) )
) )
response = await self._model_client.create(messages) response = await self._model_client.create(messages, cancellation_token=cancellation_token)
scaled_screenshot.close() scaled_screenshot.close()
assert isinstance(response.content, str) assert isinstance(response.content, str)
return response.content return response.content

View File

@ -1,7 +1,7 @@
import json import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from autogen_core.base import AgentProxy, MessageContext, TopicId from autogen_core.base import AgentProxy, MessageContext, TopicId, CancellationToken
from autogen_core.components import default_subscription from autogen_core.components import default_subscription
from autogen_core.components.models import ( from autogen_core.components.models import (
AssistantMessage, AssistantMessage,
@ -12,7 +12,7 @@ from autogen_core.components.models import (
) )
from ..messages import BroadcastMessage, OrchestrationEvent, ResetMessage from ..messages import BroadcastMessage, OrchestrationEvent, ResetMessage
from .base_orchestrator import BaseOrchestrator, logger from .base_orchestrator import BaseOrchestrator
from .orchestrator_prompts import ( from .orchestrator_prompts import (
ORCHESTRATOR_CLOSED_BOOK_PROMPT, ORCHESTRATOR_CLOSED_BOOK_PROMPT,
ORCHESTRATOR_LEDGER_PROMPT, ORCHESTRATOR_LEDGER_PROMPT,
@ -128,7 +128,7 @@ class LedgerOrchestrator(BaseOrchestrator):
assert len(result) > 0 assert len(result) > 0
return result return result
async def _initialize_task(self, task: str) -> None: async def _initialize_task(self, task: str, cancellation_token: Optional[CancellationToken] = None) -> None:
self._task = task self._task = task
self._team_description = await self._get_team_description() self._team_description = await self._get_team_description()
@ -140,7 +140,9 @@ class LedgerOrchestrator(BaseOrchestrator):
planning_conversation.append( planning_conversation.append(
UserMessage(content=self._get_closed_book_prompt(self._task), source=self.metadata["type"]) UserMessage(content=self._get_closed_book_prompt(self._task), source=self.metadata["type"])
) )
response = await self._model_client.create(self._system_messages + planning_conversation) response = await self._model_client.create(
self._system_messages + planning_conversation, cancellation_token=cancellation_token
)
assert isinstance(response.content, str) assert isinstance(response.content, str)
self._facts = response.content self._facts = response.content
@ -151,14 +153,16 @@ class LedgerOrchestrator(BaseOrchestrator):
planning_conversation.append( planning_conversation.append(
UserMessage(content=self._get_plan_prompt(self._team_description), source=self.metadata["type"]) UserMessage(content=self._get_plan_prompt(self._team_description), source=self.metadata["type"])
) )
response = await self._model_client.create(self._system_messages + planning_conversation) response = await self._model_client.create(
self._system_messages + planning_conversation, cancellation_token=cancellation_token
)
assert isinstance(response.content, str) assert isinstance(response.content, str)
self._plan = response.content self._plan = response.content
# At this point, the planning conversation is dropped. # At this point, the planning conversation is dropped.
async def _update_facts_and_plan(self) -> None: async def _update_facts_and_plan(self, cancellation_token: Optional[CancellationToken] = None) -> None:
# Shallow-copy the conversation # Shallow-copy the conversation
planning_conversation = [m for m in self._chat_history] planning_conversation = [m for m in self._chat_history]
@ -166,7 +170,9 @@ class LedgerOrchestrator(BaseOrchestrator):
planning_conversation.append( planning_conversation.append(
UserMessage(content=self._get_update_facts_prompt(self._task, self._facts), source=self.metadata["type"]) UserMessage(content=self._get_update_facts_prompt(self._task, self._facts), source=self.metadata["type"])
) )
response = await self._model_client.create(self._system_messages + planning_conversation) response = await self._model_client.create(
self._system_messages + planning_conversation, cancellation_token=cancellation_token
)
assert isinstance(response.content, str) assert isinstance(response.content, str)
self._facts = response.content self._facts = response.content
@ -176,14 +182,16 @@ class LedgerOrchestrator(BaseOrchestrator):
planning_conversation.append( planning_conversation.append(
UserMessage(content=self._get_update_plan_prompt(self._team_description), source=self.metadata["type"]) UserMessage(content=self._get_update_plan_prompt(self._team_description), source=self.metadata["type"])
) )
response = await self._model_client.create(self._system_messages + planning_conversation) response = await self._model_client.create(
self._system_messages + planning_conversation, cancellation_token=cancellation_token
)
assert isinstance(response.content, str) assert isinstance(response.content, str)
self._plan = response.content self._plan = response.content
# At this point, the planning conversation is dropped. # At this point, the planning conversation is dropped.
async def update_ledger(self) -> Dict[str, Any]: async def update_ledger(self, cancellation_token: Optional[CancellationToken] = None) -> Dict[str, Any]:
max_json_retries = 10 max_json_retries = 10
team_description = await self._get_team_description() team_description = await self._get_team_description()
@ -197,6 +205,7 @@ class LedgerOrchestrator(BaseOrchestrator):
ledger_response = await self._model_client.create( ledger_response = await self._model_client.create(
self._system_messages + self._chat_history + ledger_user_messages, self._system_messages + self._chat_history + ledger_user_messages,
json_output=True, json_output=True,
cancellation_token=cancellation_token,
) )
ledger_str = ledger_response.content ledger_str = ledger_response.content
@ -230,7 +239,7 @@ class LedgerOrchestrator(BaseOrchestrator):
continue continue
return ledger_dict return ledger_dict
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.info( self.logger.info(
OrchestrationEvent( OrchestrationEvent(
f"{self.metadata['type']} (error)", f"{self.metadata['type']} (error)",
f"Failed to parse ledger information: {ledger_str}", f"Failed to parse ledger information: {ledger_str}",
@ -244,10 +253,12 @@ class LedgerOrchestrator(BaseOrchestrator):
self._chat_history.append(message.content) self._chat_history.append(message.content)
await super()._handle_broadcast(message, ctx) await super()._handle_broadcast(message, ctx)
async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]: async def _select_next_agent(
self, message: LLMMessage, cancellation_token: Optional[CancellationToken] = None
) -> Optional[AgentProxy]:
# Check if the task is still unset, in which case this message contains the task string # Check if the task is still unset, in which case this message contains the task string
if len(self._task) == 0: if len(self._task) == 0:
await self._initialize_task(self._get_message_str(message)) await self._initialize_task(self._get_message_str(message), cancellation_token)
# At this point the task, plan and facts shouls all be set # At this point the task, plan and facts shouls all be set
assert len(self._task) > 0 assert len(self._task) > 0
@ -263,9 +274,10 @@ class LedgerOrchestrator(BaseOrchestrator):
await self.publish_message( await self.publish_message(
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])), BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])),
topic_id=topic_id, topic_id=topic_id,
cancellation_token=cancellation_token,
) )
logger.info( self.logger.info(
OrchestrationEvent( OrchestrationEvent(
f"{self.metadata['type']} (thought)", f"{self.metadata['type']} (thought)",
f"Initial plan:\n{synthesized_prompt}", f"Initial plan:\n{synthesized_prompt}",
@ -279,11 +291,11 @@ class LedgerOrchestrator(BaseOrchestrator):
self._chat_history.append(synthesized_message) self._chat_history.append(synthesized_message)
# Answer from this synthesized message # Answer from this synthesized message
return await self._select_next_agent(synthesized_message) return await self._select_next_agent(synthesized_message, cancellation_token)
# Orchestrate the next step # Orchestrate the next step
ledger_dict = await self.update_ledger() ledger_dict = await self.update_ledger(cancellation_token)
logger.info( self.logger.info(
OrchestrationEvent( OrchestrationEvent(
f"{self.metadata['type']} (thought)", f"{self.metadata['type']} (thought)",
f"Updated Ledger:\n{json.dumps(ledger_dict, indent=2)}", f"Updated Ledger:\n{json.dumps(ledger_dict, indent=2)}",
@ -292,7 +304,7 @@ class LedgerOrchestrator(BaseOrchestrator):
# Task is complete # Task is complete
if ledger_dict["is_request_satisfied"]["answer"] is True: if ledger_dict["is_request_satisfied"]["answer"] is True:
logger.info( self.logger.info(
OrchestrationEvent( OrchestrationEvent(
f"{self.metadata['type']} (thought)", f"{self.metadata['type']} (thought)",
"Request satisfied.", "Request satisfied.",
@ -312,7 +324,7 @@ class LedgerOrchestrator(BaseOrchestrator):
# We exceeded our replan counter # We exceeded our replan counter
if self._replan_counter > self._max_replans: if self._replan_counter > self._max_replans:
logger.info( self.logger.info(
OrchestrationEvent( OrchestrationEvent(
f"{self.metadata['type']} (thought)", f"{self.metadata['type']} (thought)",
"Replan counter exceeded... Terminating.", "Replan counter exceeded... Terminating.",
@ -321,7 +333,7 @@ class LedgerOrchestrator(BaseOrchestrator):
return None return None
# Let's create a new plan # Let's create a new plan
else: else:
logger.info( self.logger.info(
OrchestrationEvent( OrchestrationEvent(
f"{self.metadata['type']} (thought)", f"{self.metadata['type']} (thought)",
"Stalled.... Replanning...", "Stalled.... Replanning...",
@ -329,24 +341,24 @@ class LedgerOrchestrator(BaseOrchestrator):
) )
# Update our plan. # Update our plan.
await self._update_facts_and_plan() await self._update_facts_and_plan(cancellation_token)
# Reset everyone, then rebroadcast the new plan # Reset everyone, then rebroadcast the new plan
self._chat_history = [self._chat_history[0]] self._chat_history = [self._chat_history[0]]
topic_id = TopicId("default", self.id.key) topic_id = TopicId("default", self.id.key)
await self.publish_message(ResetMessage(), topic_id=topic_id) await self.publish_message(ResetMessage(), topic_id=topic_id, cancellation_token=cancellation_token)
# Send everyone the NEW plan # Send everyone the NEW plan
synthesized_prompt = self._get_synthesize_prompt( synthesized_prompt = self._get_synthesize_prompt(
self._task, self._team_description, self._facts, self._plan self._task, self._team_description, self._facts, self._plan
) )
topic_id = TopicId("default", self.id.key)
await self.publish_message( await self.publish_message(
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])), BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])),
topic_id=topic_id, topic_id=topic_id,
cancellation_token=cancellation_token,
) )
logger.info( self.logger.info(
OrchestrationEvent( OrchestrationEvent(
f"{self.metadata['type']} (thought)", f"{self.metadata['type']} (thought)",
f"New plan:\n{synthesized_prompt}", f"New plan:\n{synthesized_prompt}",
@ -357,7 +369,7 @@ class LedgerOrchestrator(BaseOrchestrator):
self._chat_history.append(synthesized_message) self._chat_history.append(synthesized_message)
# Answer from this synthesized message # Answer from this synthesized message
return await self._select_next_agent(synthesized_message) return await self._select_next_agent(synthesized_message, cancellation_token)
# If we goit this far, we were not starting, done, or stuck # If we goit this far, we were not starting, done, or stuck
next_agent_name = ledger_dict["next_speaker"]["answer"] next_agent_name = ledger_dict["next_speaker"]["answer"]
@ -367,12 +379,13 @@ class LedgerOrchestrator(BaseOrchestrator):
instruction = ledger_dict["instruction_or_question"]["answer"] instruction = ledger_dict["instruction_or_question"]["answer"]
user_message = UserMessage(content=instruction, source=self.metadata["type"]) user_message = UserMessage(content=instruction, source=self.metadata["type"])
assistant_message = AssistantMessage(content=instruction, source=self.metadata["type"]) assistant_message = AssistantMessage(content=instruction, source=self.metadata["type"])
logger.info(OrchestrationEvent(f"{self.metadata['type']} (-> {next_agent_name})", instruction)) self.logger.info(OrchestrationEvent(f"{self.metadata['type']} (-> {next_agent_name})", instruction))
self._chat_history.append(assistant_message) # My copy self._chat_history.append(assistant_message) # My copy
topic_id = TopicId("default", self.id.key) topic_id = TopicId("default", self.id.key)
await self.publish_message( await self.publish_message(
BroadcastMessage(content=user_message, request_halt=False), BroadcastMessage(content=user_message, request_halt=False),
topic_id=topic_id, topic_id=topic_id,
cancellation_token=cancellation_token,
) # Send to everyone else ) # Send to everyone else
return agent return agent