TeamOne cancellation token support and making logger a member variable (#622)

* Added cancellation token support for team_one; made logger a member variable of each agent.

formatting

fix error

fix error

formatting

* No need to create a new cancellation token
This commit is contained in:
Enhao Zhang 2024-09-24 16:54:22 -07:00 committed by GitHub
parent 6dcbf869ad
commit afdfb4ea58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 100 additions and 57 deletions

View File

@ -15,8 +15,6 @@ from team_one.messages import (
TeamOneMessages,
)
logger = logging.getLogger(EVENT_LOGGER_NAME + ".agent")
class TeamOneBaseAgent(RoutedAgent):
"""An agent that optionally ensures messages are handled non-concurrently in the order they arrive."""
@ -29,6 +27,7 @@ class TeamOneBaseAgent(RoutedAgent):
super().__init__(description)
self._handle_messages_concurrently = handle_messages_concurrently
self._enabled = True
self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.id.key}.agent")
if not self._handle_messages_concurrently:
# TODO: make it possible to stop
@ -40,6 +39,7 @@ class TeamOneBaseAgent(RoutedAgent):
message, ctx, future = await self._message_queue.get()
if ctx.cancellation_token.is_cancelled():
# TODO: Do we need to resolve the future here?
future.cancel()
continue
try:
@ -54,6 +54,8 @@ class TeamOneBaseAgent(RoutedAgent):
else:
raise ValueError("Unknown message type.")
future.set_result(None)
except asyncio.CancelledError:
future.cancel()
except Exception as e:
future.set_exception(e)
@ -92,9 +94,19 @@ class TeamOneBaseAgent(RoutedAgent):
async def _handle_deactivate(self, message: DeactivateMessage, ctx: MessageContext) -> None:
"""Handle a deactivate message."""
self._enabled = False
logger.info(
self.logger.info(
AgentEvent(
f"{self.metadata['type']} (deactivated)",
"",
)
)
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
"""Drop the message silently; the log call below is currently commented out."""
# self.logger.info(
# AgentEvent(
# f"{self.metadata['type']} (unhandled message)",
# f"Unhandled message type: {type(message)}",
# )
# )
pass

View File

@ -10,8 +10,6 @@ from ..messages import BroadcastMessage, OrchestrationEvent, RequestReplyMessage
from ..utils import message_content_to_str
from .base_agent import TeamOneBaseAgent
logger = logging.getLogger(EVENT_LOGGER_NAME + ".orchestrator")
class BaseOrchestrator(TeamOneBaseAgent):
def __init__(
@ -28,6 +26,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
self._max_time = max_time
self._num_rounds = 0
self._start_time: float = -1.0
self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.id.key}.orchestrator")
async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext) -> None:
"""Handle an incoming message."""
@ -42,11 +41,11 @@ class BaseOrchestrator(TeamOneBaseAgent):
content = message_content_to_str(message.content.content)
logger.info(OrchestrationEvent(source, content))
self.logger.info(OrchestrationEvent(source, content))
# Termination conditions
if self._num_rounds >= self._max_rounds:
logger.info(
self.logger.info(
OrchestrationEvent(
f"{self.metadata['type']} (termination condition)",
f"Max rounds ({self._max_rounds}) reached.",
@ -55,7 +54,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
return
if time.time() - self._start_time >= self._max_time:
logger.info(
self.logger.info(
OrchestrationEvent(
f"{self.metadata['type']} (termination condition)",
f"Max time ({self._max_time}s) reached.",
@ -64,7 +63,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
return
if message.request_halt:
logger.info(
self.logger.info(
OrchestrationEvent(
f"{self.metadata['type']} (termination condition)",
f"{source} requested halt.",
@ -74,7 +73,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
next_agent = await self._select_next_agent(message.content)
if next_agent is None:
logger.info(
self.logger.info(
OrchestrationEvent(
f"{self.metadata['type']} (termination condition)",
"No agent selected.",
@ -84,7 +83,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
request_reply_message = RequestReplyMessage()
# emit an event
logger.info(
self.logger.info(
OrchestrationEvent(
source=f"{self.metadata['type']} (thought)",
message=f"Next speaker {(await next_agent.metadata)['type']}" "",
@ -92,7 +91,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
)
self._num_rounds += 1 # Call before sending the message
await self.send_message(request_reply_message, next_agent.id)
await self.send_message(request_reply_message, next_agent.id, cancellation_token=ctx.cancellation_token)
async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
"""Select the agent that should speak next; this base version always raises NotImplementedError."""
raise NotImplementedError()

View File

@ -46,7 +46,11 @@ class BaseWorker(TeamOneBaseAgent):
user_message = UserMessage(content=response, source=self.metadata["type"])
topic_id = TopicId("default", self.id.key)
await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt), topic_id=topic_id)
await self.publish_message(
BroadcastMessage(content=user_message, request_halt=request_halt),
topic_id=topic_id,
cancellation_token=ctx.cancellation_token,
)
async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]:
"""Returns (request_halt, response_message)"""

View File

@ -49,7 +49,9 @@ Reply "TERMINATE" in the end when everything is done.""")
"""Respond to a reply request."""
# Make an inference to the model.
response = await self._model_client.create(self._system_messages + self._chat_history)
response = await self._model_client.create(
self._system_messages + self._chat_history, cancellation_token=cancellation_token
)
assert isinstance(response.content, str)
return "TERMINATE" in response.content, response.content

View File

@ -90,7 +90,7 @@ class FileSurfer(BaseWorker):
)
create_result = await self._model_client.create(
messages=history + [context_message, task_message], tools=self._tools
messages=history + [context_message, task_message], tools=self._tools, cancellation_token=cancellation_token
)
response = create_result.content

View File

@ -7,7 +7,7 @@ import os
import pathlib
import re
import traceback
from typing import Any, BinaryIO, Dict, List, Tuple, Union, cast # Any, Callable, Dict, List, Literal, Tuple
from typing import Any, BinaryIO, Dict, List, Tuple, Union, cast, Optional # Any, Callable, Dict, List, Literal, Tuple
from urllib.parse import quote_plus # parse_qs, quote, unquote, urlparse, urlunparse
import aiofiles
@ -67,8 +67,6 @@ MLM_WIDTH = 1224
SCREENSHOT_TOKENS = 1105
logger = logging.getLogger(EVENT_LOGGER_NAME + ".MultimodalWebSurfer")
# Sentinels
class DEFAULT_CHANNEL(metaclass=SentinelMeta):
@ -96,6 +94,7 @@ class MultimodalWebSurfer(BaseWorker):
self._page: Page | None = None
self._last_download: Download | None = None
self._prior_metadata_hash: str | None = None
self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.id.key}.MultimodalWebSurfer")
# Read page_script
self._page_script: str = ""
@ -196,7 +195,7 @@ setInterval(function() {{
""".strip(),
)
await self._page.screenshot(path=os.path.join(self.debug_dir, "screenshot.png"))
logger.info(f"Multimodal Web Surfer debug screens: {pathlib.Path(os.path.abspath(debug_html)).as_uri()}\n")
self.logger.info(f"Multimodal Web Surfer debug screens: {pathlib.Path(os.path.abspath(debug_html)).as_uri()}\n")
async def _reset(self, cancellation_token: CancellationToken) -> None:
assert self._page is not None
@ -205,7 +204,7 @@ setInterval(function() {{
await self._visit_page(self.start_page)
if self.debug_dir:
await self._page.screenshot(path=os.path.join(self.debug_dir, "screenshot.png"))
logger.info(
self.logger.info(
WebSurferEvent(
source=self.metadata["type"],
url=self._page.url,
@ -250,13 +249,18 @@ setInterval(function() {{
return False, f"Web surfing error:\n\n{traceback.format_exc()}"
async def _execute_tool(
self, message: List[FunctionCall], rects: Dict[str, InteractiveRegion], tool_names: str, use_ocr: bool = True
self,
message: List[FunctionCall],
rects: Dict[str, InteractiveRegion],
tool_names: str,
use_ocr: bool = True,
cancellation_token: Optional[CancellationToken] = None,
) -> Tuple[bool, UserContent]:
name = message[0].name
args = json.loads(message[0].arguments)
action_description = ""
assert self._page is not None
logger.info(
self.logger.info(
WebSurferEvent(
source=self.metadata["type"],
url=self._page.url,
@ -340,11 +344,11 @@ setInterval(function() {{
elif name == "answer_question":
question = str(args.get("question"))
# Do Q&A on the DOM. No need to take further action. Browser state does not change.
return False, await self._summarize_page(question=question)
return False, await self._summarize_page(question=question, cancellation_token=cancellation_token)
elif name == "summarize_page":
# Summarize the DOM. No need to take further action. Browser state does not change.
return False, await self._summarize_page()
return False, await self._summarize_page(cancellation_token=cancellation_token)
elif name == "sleep":
action_description = "I am waiting a short period of time before taking further action."
@ -394,7 +398,9 @@ setInterval(function() {{
async with aiofiles.open(os.path.join(self.debug_dir, "screenshot.png"), "wb") as file:
await file.write(new_screenshot)
ocr_text = await self._get_ocr_text(new_screenshot) if use_ocr is True else ""
ocr_text = (
await self._get_ocr_text(new_screenshot, cancellation_token=cancellation_token) if use_ocr is True else ""
)
# Return the complete observation
message_content = "" # message.content or ""
@ -518,7 +524,7 @@ When deciding between tools, consider if the request can be best addressed by:
UserMessage(content=[text_prompt, AGImage.from_pil(scaled_screenshot)], source=self.metadata["type"])
)
response = await self._model_client.create(
history, tools=tools, extra_create_args={"tool_choice": "auto"}
history, tools=tools, extra_create_args={"tool_choice": "auto"}, cancellation_token=cancellation_token
) # , "parallel_tool_calls": False})
message = response.content
@ -529,7 +535,7 @@ When deciding between tools, consider if the request can be best addressed by:
return False, message
elif isinstance(message, list):
# Take an action
return await self._execute_tool(message, rects, tool_names)
return await self._execute_tool(message, rects, tool_names, cancellation_token=cancellation_token)
else:
# Not sure what happened here
raise AssertionError(f"Unknown response format '{message}'")
@ -668,7 +674,7 @@ When deciding between tools, consider if the request can be best addressed by:
assert isinstance(new_page, Page)
await self._on_new_page(new_page)
logger.info(
self.logger.info(
WebSurferEvent(
source=self.metadata["type"],
url=self._page.url,
@ -716,7 +722,12 @@ When deciding between tools, consider if the request can be best addressed by:
"""
)
async def _summarize_page(self, question: str | None = None, token_limit: int = 100000) -> str:
async def _summarize_page(
self,
question: str | None = None,
token_limit: int = 100000,
cancellation_token: Optional[CancellationToken] = None,
) -> str:
assert self._page is not None
page_markdown: str = await self._get_page_markdown()
@ -780,12 +791,14 @@ When deciding between tools, consider if the request can be best addressed by:
)
# Generate the response
response = await self._model_client.create(messages)
response = await self._model_client.create(messages, cancellation_token=cancellation_token)
scaled_screenshot.close()
assert isinstance(response.content, str)
return response.content
async def _get_ocr_text(self, image: bytes | io.BufferedIOBase | Image.Image) -> str:
async def _get_ocr_text(
self, image: bytes | io.BufferedIOBase | Image.Image, cancellation_token: Optional[CancellationToken] = None
) -> str:
scaled_screenshot = None
if isinstance(image, Image.Image):
scaled_screenshot = image.resize((MLM_WIDTH, MLM_HEIGHT))
@ -810,7 +823,7 @@ When deciding between tools, consider if the request can be best addressed by:
source=self.metadata["type"],
)
)
response = await self._model_client.create(messages)
response = await self._model_client.create(messages, cancellation_token=cancellation_token)
scaled_screenshot.close()
assert isinstance(response.content, str)
return response.content

View File

@ -1,7 +1,7 @@
import json
from typing import Any, Dict, List, Optional
from autogen_core.base import AgentProxy, MessageContext, TopicId
from autogen_core.base import AgentProxy, MessageContext, TopicId, CancellationToken
from autogen_core.components import default_subscription
from autogen_core.components.models import (
AssistantMessage,
@ -12,7 +12,7 @@ from autogen_core.components.models import (
)
from ..messages import BroadcastMessage, OrchestrationEvent, ResetMessage
from .base_orchestrator import BaseOrchestrator, logger
from .base_orchestrator import BaseOrchestrator
from .orchestrator_prompts import (
ORCHESTRATOR_CLOSED_BOOK_PROMPT,
ORCHESTRATOR_LEDGER_PROMPT,
@ -128,7 +128,7 @@ class LedgerOrchestrator(BaseOrchestrator):
assert len(result) > 0
return result
async def _initialize_task(self, task: str) -> None:
async def _initialize_task(self, task: str, cancellation_token: Optional[CancellationToken] = None) -> None:
self._task = task
self._team_description = await self._get_team_description()
@ -140,7 +140,9 @@ class LedgerOrchestrator(BaseOrchestrator):
planning_conversation.append(
UserMessage(content=self._get_closed_book_prompt(self._task), source=self.metadata["type"])
)
response = await self._model_client.create(self._system_messages + planning_conversation)
response = await self._model_client.create(
self._system_messages + planning_conversation, cancellation_token=cancellation_token
)
assert isinstance(response.content, str)
self._facts = response.content
@ -151,14 +153,16 @@ class LedgerOrchestrator(BaseOrchestrator):
planning_conversation.append(
UserMessage(content=self._get_plan_prompt(self._team_description), source=self.metadata["type"])
)
response = await self._model_client.create(self._system_messages + planning_conversation)
response = await self._model_client.create(
self._system_messages + planning_conversation, cancellation_token=cancellation_token
)
assert isinstance(response.content, str)
self._plan = response.content
# At this point, the planning conversation is dropped.
async def _update_facts_and_plan(self) -> None:
async def _update_facts_and_plan(self, cancellation_token: Optional[CancellationToken] = None) -> None:
# Shallow-copy the conversation
planning_conversation = [m for m in self._chat_history]
@ -166,7 +170,9 @@ class LedgerOrchestrator(BaseOrchestrator):
planning_conversation.append(
UserMessage(content=self._get_update_facts_prompt(self._task, self._facts), source=self.metadata["type"])
)
response = await self._model_client.create(self._system_messages + planning_conversation)
response = await self._model_client.create(
self._system_messages + planning_conversation, cancellation_token=cancellation_token
)
assert isinstance(response.content, str)
self._facts = response.content
@ -176,14 +182,16 @@ class LedgerOrchestrator(BaseOrchestrator):
planning_conversation.append(
UserMessage(content=self._get_update_plan_prompt(self._team_description), source=self.metadata["type"])
)
response = await self._model_client.create(self._system_messages + planning_conversation)
response = await self._model_client.create(
self._system_messages + planning_conversation, cancellation_token=cancellation_token
)
assert isinstance(response.content, str)
self._plan = response.content
# At this point, the planning conversation is dropped.
async def update_ledger(self) -> Dict[str, Any]:
async def update_ledger(self, cancellation_token: Optional[CancellationToken] = None) -> Dict[str, Any]:
max_json_retries = 10
team_description = await self._get_team_description()
@ -197,6 +205,7 @@ class LedgerOrchestrator(BaseOrchestrator):
ledger_response = await self._model_client.create(
self._system_messages + self._chat_history + ledger_user_messages,
json_output=True,
cancellation_token=cancellation_token,
)
ledger_str = ledger_response.content
@ -230,7 +239,7 @@ class LedgerOrchestrator(BaseOrchestrator):
continue
return ledger_dict
except json.JSONDecodeError as e:
logger.info(
self.logger.info(
OrchestrationEvent(
f"{self.metadata['type']} (error)",
f"Failed to parse ledger information: {ledger_str}",
@ -244,10 +253,12 @@ class LedgerOrchestrator(BaseOrchestrator):
self._chat_history.append(message.content)
await super()._handle_broadcast(message, ctx)
async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
async def _select_next_agent(
self, message: LLMMessage, cancellation_token: Optional[CancellationToken] = None
) -> Optional[AgentProxy]:
# Check if the task is still unset, in which case this message contains the task string
if len(self._task) == 0:
await self._initialize_task(self._get_message_str(message))
await self._initialize_task(self._get_message_str(message), cancellation_token)
# At this point the task, plan and facts should all be set
assert len(self._task) > 0
@ -263,9 +274,10 @@ class LedgerOrchestrator(BaseOrchestrator):
await self.publish_message(
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])),
topic_id=topic_id,
cancellation_token=cancellation_token,
)
logger.info(
self.logger.info(
OrchestrationEvent(
f"{self.metadata['type']} (thought)",
f"Initial plan:\n{synthesized_prompt}",
@ -279,11 +291,11 @@ class LedgerOrchestrator(BaseOrchestrator):
self._chat_history.append(synthesized_message)
# Answer from this synthesized message
return await self._select_next_agent(synthesized_message)
return await self._select_next_agent(synthesized_message, cancellation_token)
# Orchestrate the next step
ledger_dict = await self.update_ledger()
logger.info(
ledger_dict = await self.update_ledger(cancellation_token)
self.logger.info(
OrchestrationEvent(
f"{self.metadata['type']} (thought)",
f"Updated Ledger:\n{json.dumps(ledger_dict, indent=2)}",
@ -292,7 +304,7 @@ class LedgerOrchestrator(BaseOrchestrator):
# Task is complete
if ledger_dict["is_request_satisfied"]["answer"] is True:
logger.info(
self.logger.info(
OrchestrationEvent(
f"{self.metadata['type']} (thought)",
"Request satisfied.",
@ -312,7 +324,7 @@ class LedgerOrchestrator(BaseOrchestrator):
# We exceeded our replan counter
if self._replan_counter > self._max_replans:
logger.info(
self.logger.info(
OrchestrationEvent(
f"{self.metadata['type']} (thought)",
"Replan counter exceeded... Terminating.",
@ -321,7 +333,7 @@ class LedgerOrchestrator(BaseOrchestrator):
return None
# Let's create a new plan
else:
logger.info(
self.logger.info(
OrchestrationEvent(
f"{self.metadata['type']} (thought)",
"Stalled.... Replanning...",
@ -329,24 +341,24 @@ class LedgerOrchestrator(BaseOrchestrator):
)
# Update our plan.
await self._update_facts_and_plan()
await self._update_facts_and_plan(cancellation_token)
# Reset everyone, then rebroadcast the new plan
self._chat_history = [self._chat_history[0]]
topic_id = TopicId("default", self.id.key)
await self.publish_message(ResetMessage(), topic_id=topic_id)
await self.publish_message(ResetMessage(), topic_id=topic_id, cancellation_token=cancellation_token)
# Send everyone the NEW plan
synthesized_prompt = self._get_synthesize_prompt(
self._task, self._team_description, self._facts, self._plan
)
topic_id = TopicId("default", self.id.key)
await self.publish_message(
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])),
topic_id=topic_id,
cancellation_token=cancellation_token,
)
logger.info(
self.logger.info(
OrchestrationEvent(
f"{self.metadata['type']} (thought)",
f"New plan:\n{synthesized_prompt}",
@ -357,7 +369,7 @@ class LedgerOrchestrator(BaseOrchestrator):
self._chat_history.append(synthesized_message)
# Answer from this synthesized message
return await self._select_next_agent(synthesized_message)
return await self._select_next_agent(synthesized_message, cancellation_token)
# If we got this far, we were not starting, done, or stuck
next_agent_name = ledger_dict["next_speaker"]["answer"]
@ -367,12 +379,13 @@ class LedgerOrchestrator(BaseOrchestrator):
instruction = ledger_dict["instruction_or_question"]["answer"]
user_message = UserMessage(content=instruction, source=self.metadata["type"])
assistant_message = AssistantMessage(content=instruction, source=self.metadata["type"])
logger.info(OrchestrationEvent(f"{self.metadata['type']} (-> {next_agent_name})", instruction))
self.logger.info(OrchestrationEvent(f"{self.metadata['type']} (-> {next_agent_name})", instruction))
self._chat_history.append(assistant_message) # My copy
topic_id = TopicId("default", self.id.key)
await self.publish_message(
BroadcastMessage(content=user_message, request_halt=False),
topic_id=topic_id,
cancellation_token=cancellation_token,
) # Send to everyone else
return agent