mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-05 20:34:19 +00:00
TeamOne cancellation token support and making logger a member variable (#622)
* Added cancellation token support for team_one; made logger a member variable of each agent. formatting fix error fix error formatting * No need to create a new cancellation token
This commit is contained in:
parent
6dcbf869ad
commit
afdfb4ea58
@ -15,8 +15,6 @@ from team_one.messages import (
|
|||||||
TeamOneMessages,
|
TeamOneMessages,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(EVENT_LOGGER_NAME + ".agent")
|
|
||||||
|
|
||||||
|
|
||||||
class TeamOneBaseAgent(RoutedAgent):
|
class TeamOneBaseAgent(RoutedAgent):
|
||||||
"""An agent that optionally ensures messages are handled non-concurrently in the order they arrive."""
|
"""An agent that optionally ensures messages are handled non-concurrently in the order they arrive."""
|
||||||
@ -29,6 +27,7 @@ class TeamOneBaseAgent(RoutedAgent):
|
|||||||
super().__init__(description)
|
super().__init__(description)
|
||||||
self._handle_messages_concurrently = handle_messages_concurrently
|
self._handle_messages_concurrently = handle_messages_concurrently
|
||||||
self._enabled = True
|
self._enabled = True
|
||||||
|
self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.id.key}.agent")
|
||||||
|
|
||||||
if not self._handle_messages_concurrently:
|
if not self._handle_messages_concurrently:
|
||||||
# TODO: make it possible to stop
|
# TODO: make it possible to stop
|
||||||
@ -40,6 +39,7 @@ class TeamOneBaseAgent(RoutedAgent):
|
|||||||
message, ctx, future = await self._message_queue.get()
|
message, ctx, future = await self._message_queue.get()
|
||||||
if ctx.cancellation_token.is_cancelled():
|
if ctx.cancellation_token.is_cancelled():
|
||||||
# TODO: Do we need to resolve the future here?
|
# TODO: Do we need to resolve the future here?
|
||||||
|
future.cancel()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -54,6 +54,8 @@ class TeamOneBaseAgent(RoutedAgent):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Unknown message type.")
|
raise ValueError("Unknown message type.")
|
||||||
future.set_result(None)
|
future.set_result(None)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
future.cancel()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
future.set_exception(e)
|
future.set_exception(e)
|
||||||
|
|
||||||
@ -92,9 +94,19 @@ class TeamOneBaseAgent(RoutedAgent):
|
|||||||
async def _handle_deactivate(self, message: DeactivateMessage, ctx: MessageContext) -> None:
|
async def _handle_deactivate(self, message: DeactivateMessage, ctx: MessageContext) -> None:
|
||||||
"""Handle a deactivate message."""
|
"""Handle a deactivate message."""
|
||||||
self._enabled = False
|
self._enabled = False
|
||||||
logger.info(
|
self.logger.info(
|
||||||
AgentEvent(
|
AgentEvent(
|
||||||
f"{self.metadata['type']} (deactivated)",
|
f"{self.metadata['type']} (deactivated)",
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
|
||||||
|
"""Drop the message, with a log."""
|
||||||
|
# self.logger.info(
|
||||||
|
# AgentEvent(
|
||||||
|
# f"{self.metadata['type']} (unhandled message)",
|
||||||
|
# f"Unhandled message type: {type(message)}",
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
pass
|
||||||
|
|||||||
@ -10,8 +10,6 @@ from ..messages import BroadcastMessage, OrchestrationEvent, RequestReplyMessage
|
|||||||
from ..utils import message_content_to_str
|
from ..utils import message_content_to_str
|
||||||
from .base_agent import TeamOneBaseAgent
|
from .base_agent import TeamOneBaseAgent
|
||||||
|
|
||||||
logger = logging.getLogger(EVENT_LOGGER_NAME + ".orchestrator")
|
|
||||||
|
|
||||||
|
|
||||||
class BaseOrchestrator(TeamOneBaseAgent):
|
class BaseOrchestrator(TeamOneBaseAgent):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -28,6 +26,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
|
|||||||
self._max_time = max_time
|
self._max_time = max_time
|
||||||
self._num_rounds = 0
|
self._num_rounds = 0
|
||||||
self._start_time: float = -1.0
|
self._start_time: float = -1.0
|
||||||
|
self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.id.key}.orchestrator")
|
||||||
|
|
||||||
async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext) -> None:
|
async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext) -> None:
|
||||||
"""Handle an incoming message."""
|
"""Handle an incoming message."""
|
||||||
@ -42,11 +41,11 @@ class BaseOrchestrator(TeamOneBaseAgent):
|
|||||||
|
|
||||||
content = message_content_to_str(message.content.content)
|
content = message_content_to_str(message.content.content)
|
||||||
|
|
||||||
logger.info(OrchestrationEvent(source, content))
|
self.logger.info(OrchestrationEvent(source, content))
|
||||||
|
|
||||||
# Termination conditions
|
# Termination conditions
|
||||||
if self._num_rounds >= self._max_rounds:
|
if self._num_rounds >= self._max_rounds:
|
||||||
logger.info(
|
self.logger.info(
|
||||||
OrchestrationEvent(
|
OrchestrationEvent(
|
||||||
f"{self.metadata['type']} (termination condition)",
|
f"{self.metadata['type']} (termination condition)",
|
||||||
f"Max rounds ({self._max_rounds}) reached.",
|
f"Max rounds ({self._max_rounds}) reached.",
|
||||||
@ -55,7 +54,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if time.time() - self._start_time >= self._max_time:
|
if time.time() - self._start_time >= self._max_time:
|
||||||
logger.info(
|
self.logger.info(
|
||||||
OrchestrationEvent(
|
OrchestrationEvent(
|
||||||
f"{self.metadata['type']} (termination condition)",
|
f"{self.metadata['type']} (termination condition)",
|
||||||
f"Max time ({self._max_time}s) reached.",
|
f"Max time ({self._max_time}s) reached.",
|
||||||
@ -64,7 +63,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if message.request_halt:
|
if message.request_halt:
|
||||||
logger.info(
|
self.logger.info(
|
||||||
OrchestrationEvent(
|
OrchestrationEvent(
|
||||||
f"{self.metadata['type']} (termination condition)",
|
f"{self.metadata['type']} (termination condition)",
|
||||||
f"{source} requested halt.",
|
f"{source} requested halt.",
|
||||||
@ -74,7 +73,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
|
|||||||
|
|
||||||
next_agent = await self._select_next_agent(message.content)
|
next_agent = await self._select_next_agent(message.content)
|
||||||
if next_agent is None:
|
if next_agent is None:
|
||||||
logger.info(
|
self.logger.info(
|
||||||
OrchestrationEvent(
|
OrchestrationEvent(
|
||||||
f"{self.metadata['type']} (termination condition)",
|
f"{self.metadata['type']} (termination condition)",
|
||||||
"No agent selected.",
|
"No agent selected.",
|
||||||
@ -84,7 +83,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
|
|||||||
request_reply_message = RequestReplyMessage()
|
request_reply_message = RequestReplyMessage()
|
||||||
# emit an event
|
# emit an event
|
||||||
|
|
||||||
logger.info(
|
self.logger.info(
|
||||||
OrchestrationEvent(
|
OrchestrationEvent(
|
||||||
source=f"{self.metadata['type']} (thought)",
|
source=f"{self.metadata['type']} (thought)",
|
||||||
message=f"Next speaker {(await next_agent.metadata)['type']}" "",
|
message=f"Next speaker {(await next_agent.metadata)['type']}" "",
|
||||||
@ -92,7 +91,7 @@ class BaseOrchestrator(TeamOneBaseAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._num_rounds += 1 # Call before sending the message
|
self._num_rounds += 1 # Call before sending the message
|
||||||
await self.send_message(request_reply_message, next_agent.id)
|
await self.send_message(request_reply_message, next_agent.id, cancellation_token=ctx.cancellation_token)
|
||||||
|
|
||||||
async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
|
async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|||||||
@ -46,7 +46,11 @@ class BaseWorker(TeamOneBaseAgent):
|
|||||||
|
|
||||||
user_message = UserMessage(content=response, source=self.metadata["type"])
|
user_message = UserMessage(content=response, source=self.metadata["type"])
|
||||||
topic_id = TopicId("default", self.id.key)
|
topic_id = TopicId("default", self.id.key)
|
||||||
await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt), topic_id=topic_id)
|
await self.publish_message(
|
||||||
|
BroadcastMessage(content=user_message, request_halt=request_halt),
|
||||||
|
topic_id=topic_id,
|
||||||
|
cancellation_token=ctx.cancellation_token,
|
||||||
|
)
|
||||||
|
|
||||||
async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]:
|
async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]:
|
||||||
"""Returns (request_halt, response_message)"""
|
"""Returns (request_halt, response_message)"""
|
||||||
|
|||||||
@ -49,7 +49,9 @@ Reply "TERMINATE" in the end when everything is done.""")
|
|||||||
"""Respond to a reply request."""
|
"""Respond to a reply request."""
|
||||||
|
|
||||||
# Make an inference to the model.
|
# Make an inference to the model.
|
||||||
response = await self._model_client.create(self._system_messages + self._chat_history)
|
response = await self._model_client.create(
|
||||||
|
self._system_messages + self._chat_history, cancellation_token=cancellation_token
|
||||||
|
)
|
||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
return "TERMINATE" in response.content, response.content
|
return "TERMINATE" in response.content, response.content
|
||||||
|
|
||||||
|
|||||||
@ -90,7 +90,7 @@ class FileSurfer(BaseWorker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
create_result = await self._model_client.create(
|
create_result = await self._model_client.create(
|
||||||
messages=history + [context_message, task_message], tools=self._tools
|
messages=history + [context_message, task_message], tools=self._tools, cancellation_token=cancellation_token
|
||||||
)
|
)
|
||||||
|
|
||||||
response = create_result.content
|
response = create_result.content
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import os
|
|||||||
import pathlib
|
import pathlib
|
||||||
import re
|
import re
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Any, BinaryIO, Dict, List, Tuple, Union, cast # Any, Callable, Dict, List, Literal, Tuple
|
from typing import Any, BinaryIO, Dict, List, Tuple, Union, cast, Optional # Any, Callable, Dict, List, Literal, Tuple
|
||||||
from urllib.parse import quote_plus # parse_qs, quote, unquote, urlparse, urlunparse
|
from urllib.parse import quote_plus # parse_qs, quote, unquote, urlparse, urlunparse
|
||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
@ -67,8 +67,6 @@ MLM_WIDTH = 1224
|
|||||||
|
|
||||||
SCREENSHOT_TOKENS = 1105
|
SCREENSHOT_TOKENS = 1105
|
||||||
|
|
||||||
logger = logging.getLogger(EVENT_LOGGER_NAME + ".MultimodalWebSurfer")
|
|
||||||
|
|
||||||
|
|
||||||
# Sentinels
|
# Sentinels
|
||||||
class DEFAULT_CHANNEL(metaclass=SentinelMeta):
|
class DEFAULT_CHANNEL(metaclass=SentinelMeta):
|
||||||
@ -96,6 +94,7 @@ class MultimodalWebSurfer(BaseWorker):
|
|||||||
self._page: Page | None = None
|
self._page: Page | None = None
|
||||||
self._last_download: Download | None = None
|
self._last_download: Download | None = None
|
||||||
self._prior_metadata_hash: str | None = None
|
self._prior_metadata_hash: str | None = None
|
||||||
|
self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.id.key}.MultimodalWebSurfer")
|
||||||
|
|
||||||
# Read page_script
|
# Read page_script
|
||||||
self._page_script: str = ""
|
self._page_script: str = ""
|
||||||
@ -196,7 +195,7 @@ setInterval(function() {{
|
|||||||
""".strip(),
|
""".strip(),
|
||||||
)
|
)
|
||||||
await self._page.screenshot(path=os.path.join(self.debug_dir, "screenshot.png"))
|
await self._page.screenshot(path=os.path.join(self.debug_dir, "screenshot.png"))
|
||||||
logger.info(f"Multimodal Web Surfer debug screens: {pathlib.Path(os.path.abspath(debug_html)).as_uri()}\n")
|
self.logger.info(f"Multimodal Web Surfer debug screens: {pathlib.Path(os.path.abspath(debug_html)).as_uri()}\n")
|
||||||
|
|
||||||
async def _reset(self, cancellation_token: CancellationToken) -> None:
|
async def _reset(self, cancellation_token: CancellationToken) -> None:
|
||||||
assert self._page is not None
|
assert self._page is not None
|
||||||
@ -205,7 +204,7 @@ setInterval(function() {{
|
|||||||
await self._visit_page(self.start_page)
|
await self._visit_page(self.start_page)
|
||||||
if self.debug_dir:
|
if self.debug_dir:
|
||||||
await self._page.screenshot(path=os.path.join(self.debug_dir, "screenshot.png"))
|
await self._page.screenshot(path=os.path.join(self.debug_dir, "screenshot.png"))
|
||||||
logger.info(
|
self.logger.info(
|
||||||
WebSurferEvent(
|
WebSurferEvent(
|
||||||
source=self.metadata["type"],
|
source=self.metadata["type"],
|
||||||
url=self._page.url,
|
url=self._page.url,
|
||||||
@ -250,13 +249,18 @@ setInterval(function() {{
|
|||||||
return False, f"Web surfing error:\n\n{traceback.format_exc()}"
|
return False, f"Web surfing error:\n\n{traceback.format_exc()}"
|
||||||
|
|
||||||
async def _execute_tool(
|
async def _execute_tool(
|
||||||
self, message: List[FunctionCall], rects: Dict[str, InteractiveRegion], tool_names: str, use_ocr: bool = True
|
self,
|
||||||
|
message: List[FunctionCall],
|
||||||
|
rects: Dict[str, InteractiveRegion],
|
||||||
|
tool_names: str,
|
||||||
|
use_ocr: bool = True,
|
||||||
|
cancellation_token: Optional[CancellationToken] = None,
|
||||||
) -> Tuple[bool, UserContent]:
|
) -> Tuple[bool, UserContent]:
|
||||||
name = message[0].name
|
name = message[0].name
|
||||||
args = json.loads(message[0].arguments)
|
args = json.loads(message[0].arguments)
|
||||||
action_description = ""
|
action_description = ""
|
||||||
assert self._page is not None
|
assert self._page is not None
|
||||||
logger.info(
|
self.logger.info(
|
||||||
WebSurferEvent(
|
WebSurferEvent(
|
||||||
source=self.metadata["type"],
|
source=self.metadata["type"],
|
||||||
url=self._page.url,
|
url=self._page.url,
|
||||||
@ -340,11 +344,11 @@ setInterval(function() {{
|
|||||||
elif name == "answer_question":
|
elif name == "answer_question":
|
||||||
question = str(args.get("question"))
|
question = str(args.get("question"))
|
||||||
# Do Q&A on the DOM. No need to take further action. Browser state does not change.
|
# Do Q&A on the DOM. No need to take further action. Browser state does not change.
|
||||||
return False, await self._summarize_page(question=question)
|
return False, await self._summarize_page(question=question, cancellation_token=cancellation_token)
|
||||||
|
|
||||||
elif name == "summarize_page":
|
elif name == "summarize_page":
|
||||||
# Summarize the DOM. No need to take further action. Browser state does not change.
|
# Summarize the DOM. No need to take further action. Browser state does not change.
|
||||||
return False, await self._summarize_page()
|
return False, await self._summarize_page(cancellation_token=cancellation_token)
|
||||||
|
|
||||||
elif name == "sleep":
|
elif name == "sleep":
|
||||||
action_description = "I am waiting a short period of time before taking further action."
|
action_description = "I am waiting a short period of time before taking further action."
|
||||||
@ -394,7 +398,9 @@ setInterval(function() {{
|
|||||||
async with aiofiles.open(os.path.join(self.debug_dir, "screenshot.png"), "wb") as file:
|
async with aiofiles.open(os.path.join(self.debug_dir, "screenshot.png"), "wb") as file:
|
||||||
await file.write(new_screenshot)
|
await file.write(new_screenshot)
|
||||||
|
|
||||||
ocr_text = await self._get_ocr_text(new_screenshot) if use_ocr is True else ""
|
ocr_text = (
|
||||||
|
await self._get_ocr_text(new_screenshot, cancellation_token=cancellation_token) if use_ocr is True else ""
|
||||||
|
)
|
||||||
|
|
||||||
# Return the complete observation
|
# Return the complete observation
|
||||||
message_content = "" # message.content or ""
|
message_content = "" # message.content or ""
|
||||||
@ -518,7 +524,7 @@ When deciding between tools, consider if the request can be best addressed by:
|
|||||||
UserMessage(content=[text_prompt, AGImage.from_pil(scaled_screenshot)], source=self.metadata["type"])
|
UserMessage(content=[text_prompt, AGImage.from_pil(scaled_screenshot)], source=self.metadata["type"])
|
||||||
)
|
)
|
||||||
response = await self._model_client.create(
|
response = await self._model_client.create(
|
||||||
history, tools=tools, extra_create_args={"tool_choice": "auto"}
|
history, tools=tools, extra_create_args={"tool_choice": "auto"}, cancellation_token=cancellation_token
|
||||||
) # , "parallel_tool_calls": False})
|
) # , "parallel_tool_calls": False})
|
||||||
message = response.content
|
message = response.content
|
||||||
|
|
||||||
@ -529,7 +535,7 @@ When deciding between tools, consider if the request can be best addressed by:
|
|||||||
return False, message
|
return False, message
|
||||||
elif isinstance(message, list):
|
elif isinstance(message, list):
|
||||||
# Take an action
|
# Take an action
|
||||||
return await self._execute_tool(message, rects, tool_names)
|
return await self._execute_tool(message, rects, tool_names, cancellation_token=cancellation_token)
|
||||||
else:
|
else:
|
||||||
# Not sure what happened here
|
# Not sure what happened here
|
||||||
raise AssertionError(f"Unknown response format '{message}'")
|
raise AssertionError(f"Unknown response format '{message}'")
|
||||||
@ -668,7 +674,7 @@ When deciding between tools, consider if the request can be best addressed by:
|
|||||||
assert isinstance(new_page, Page)
|
assert isinstance(new_page, Page)
|
||||||
await self._on_new_page(new_page)
|
await self._on_new_page(new_page)
|
||||||
|
|
||||||
logger.info(
|
self.logger.info(
|
||||||
WebSurferEvent(
|
WebSurferEvent(
|
||||||
source=self.metadata["type"],
|
source=self.metadata["type"],
|
||||||
url=self._page.url,
|
url=self._page.url,
|
||||||
@ -716,7 +722,12 @@ When deciding between tools, consider if the request can be best addressed by:
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _summarize_page(self, question: str | None = None, token_limit: int = 100000) -> str:
|
async def _summarize_page(
|
||||||
|
self,
|
||||||
|
question: str | None = None,
|
||||||
|
token_limit: int = 100000,
|
||||||
|
cancellation_token: Optional[CancellationToken] = None,
|
||||||
|
) -> str:
|
||||||
assert self._page is not None
|
assert self._page is not None
|
||||||
|
|
||||||
page_markdown: str = await self._get_page_markdown()
|
page_markdown: str = await self._get_page_markdown()
|
||||||
@ -780,12 +791,14 @@ When deciding between tools, consider if the request can be best addressed by:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Generate the response
|
# Generate the response
|
||||||
response = await self._model_client.create(messages)
|
response = await self._model_client.create(messages, cancellation_token=cancellation_token)
|
||||||
scaled_screenshot.close()
|
scaled_screenshot.close()
|
||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
return response.content
|
return response.content
|
||||||
|
|
||||||
async def _get_ocr_text(self, image: bytes | io.BufferedIOBase | Image.Image) -> str:
|
async def _get_ocr_text(
|
||||||
|
self, image: bytes | io.BufferedIOBase | Image.Image, cancellation_token: Optional[CancellationToken] = None
|
||||||
|
) -> str:
|
||||||
scaled_screenshot = None
|
scaled_screenshot = None
|
||||||
if isinstance(image, Image.Image):
|
if isinstance(image, Image.Image):
|
||||||
scaled_screenshot = image.resize((MLM_WIDTH, MLM_HEIGHT))
|
scaled_screenshot = image.resize((MLM_WIDTH, MLM_HEIGHT))
|
||||||
@ -810,7 +823,7 @@ When deciding between tools, consider if the request can be best addressed by:
|
|||||||
source=self.metadata["type"],
|
source=self.metadata["type"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
response = await self._model_client.create(messages)
|
response = await self._model_client.create(messages, cancellation_token=cancellation_token)
|
||||||
scaled_screenshot.close()
|
scaled_screenshot.close()
|
||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
return response.content
|
return response.content
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from autogen_core.base import AgentProxy, MessageContext, TopicId
|
from autogen_core.base import AgentProxy, MessageContext, TopicId, CancellationToken
|
||||||
from autogen_core.components import default_subscription
|
from autogen_core.components import default_subscription
|
||||||
from autogen_core.components.models import (
|
from autogen_core.components.models import (
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
@ -12,7 +12,7 @@ from autogen_core.components.models import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from ..messages import BroadcastMessage, OrchestrationEvent, ResetMessage
|
from ..messages import BroadcastMessage, OrchestrationEvent, ResetMessage
|
||||||
from .base_orchestrator import BaseOrchestrator, logger
|
from .base_orchestrator import BaseOrchestrator
|
||||||
from .orchestrator_prompts import (
|
from .orchestrator_prompts import (
|
||||||
ORCHESTRATOR_CLOSED_BOOK_PROMPT,
|
ORCHESTRATOR_CLOSED_BOOK_PROMPT,
|
||||||
ORCHESTRATOR_LEDGER_PROMPT,
|
ORCHESTRATOR_LEDGER_PROMPT,
|
||||||
@ -128,7 +128,7 @@ class LedgerOrchestrator(BaseOrchestrator):
|
|||||||
assert len(result) > 0
|
assert len(result) > 0
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _initialize_task(self, task: str) -> None:
|
async def _initialize_task(self, task: str, cancellation_token: Optional[CancellationToken] = None) -> None:
|
||||||
self._task = task
|
self._task = task
|
||||||
self._team_description = await self._get_team_description()
|
self._team_description = await self._get_team_description()
|
||||||
|
|
||||||
@ -140,7 +140,9 @@ class LedgerOrchestrator(BaseOrchestrator):
|
|||||||
planning_conversation.append(
|
planning_conversation.append(
|
||||||
UserMessage(content=self._get_closed_book_prompt(self._task), source=self.metadata["type"])
|
UserMessage(content=self._get_closed_book_prompt(self._task), source=self.metadata["type"])
|
||||||
)
|
)
|
||||||
response = await self._model_client.create(self._system_messages + planning_conversation)
|
response = await self._model_client.create(
|
||||||
|
self._system_messages + planning_conversation, cancellation_token=cancellation_token
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
self._facts = response.content
|
self._facts = response.content
|
||||||
@ -151,14 +153,16 @@ class LedgerOrchestrator(BaseOrchestrator):
|
|||||||
planning_conversation.append(
|
planning_conversation.append(
|
||||||
UserMessage(content=self._get_plan_prompt(self._team_description), source=self.metadata["type"])
|
UserMessage(content=self._get_plan_prompt(self._team_description), source=self.metadata["type"])
|
||||||
)
|
)
|
||||||
response = await self._model_client.create(self._system_messages + planning_conversation)
|
response = await self._model_client.create(
|
||||||
|
self._system_messages + planning_conversation, cancellation_token=cancellation_token
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
self._plan = response.content
|
self._plan = response.content
|
||||||
|
|
||||||
# At this point, the planning conversation is dropped.
|
# At this point, the planning conversation is dropped.
|
||||||
|
|
||||||
async def _update_facts_and_plan(self) -> None:
|
async def _update_facts_and_plan(self, cancellation_token: Optional[CancellationToken] = None) -> None:
|
||||||
# Shallow-copy the conversation
|
# Shallow-copy the conversation
|
||||||
planning_conversation = [m for m in self._chat_history]
|
planning_conversation = [m for m in self._chat_history]
|
||||||
|
|
||||||
@ -166,7 +170,9 @@ class LedgerOrchestrator(BaseOrchestrator):
|
|||||||
planning_conversation.append(
|
planning_conversation.append(
|
||||||
UserMessage(content=self._get_update_facts_prompt(self._task, self._facts), source=self.metadata["type"])
|
UserMessage(content=self._get_update_facts_prompt(self._task, self._facts), source=self.metadata["type"])
|
||||||
)
|
)
|
||||||
response = await self._model_client.create(self._system_messages + planning_conversation)
|
response = await self._model_client.create(
|
||||||
|
self._system_messages + planning_conversation, cancellation_token=cancellation_token
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
self._facts = response.content
|
self._facts = response.content
|
||||||
@ -176,14 +182,16 @@ class LedgerOrchestrator(BaseOrchestrator):
|
|||||||
planning_conversation.append(
|
planning_conversation.append(
|
||||||
UserMessage(content=self._get_update_plan_prompt(self._team_description), source=self.metadata["type"])
|
UserMessage(content=self._get_update_plan_prompt(self._team_description), source=self.metadata["type"])
|
||||||
)
|
)
|
||||||
response = await self._model_client.create(self._system_messages + planning_conversation)
|
response = await self._model_client.create(
|
||||||
|
self._system_messages + planning_conversation, cancellation_token=cancellation_token
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
self._plan = response.content
|
self._plan = response.content
|
||||||
|
|
||||||
# At this point, the planning conversation is dropped.
|
# At this point, the planning conversation is dropped.
|
||||||
|
|
||||||
async def update_ledger(self) -> Dict[str, Any]:
|
async def update_ledger(self, cancellation_token: Optional[CancellationToken] = None) -> Dict[str, Any]:
|
||||||
max_json_retries = 10
|
max_json_retries = 10
|
||||||
|
|
||||||
team_description = await self._get_team_description()
|
team_description = await self._get_team_description()
|
||||||
@ -197,6 +205,7 @@ class LedgerOrchestrator(BaseOrchestrator):
|
|||||||
ledger_response = await self._model_client.create(
|
ledger_response = await self._model_client.create(
|
||||||
self._system_messages + self._chat_history + ledger_user_messages,
|
self._system_messages + self._chat_history + ledger_user_messages,
|
||||||
json_output=True,
|
json_output=True,
|
||||||
|
cancellation_token=cancellation_token,
|
||||||
)
|
)
|
||||||
ledger_str = ledger_response.content
|
ledger_str = ledger_response.content
|
||||||
|
|
||||||
@ -230,7 +239,7 @@ class LedgerOrchestrator(BaseOrchestrator):
|
|||||||
continue
|
continue
|
||||||
return ledger_dict
|
return ledger_dict
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.info(
|
self.logger.info(
|
||||||
OrchestrationEvent(
|
OrchestrationEvent(
|
||||||
f"{self.metadata['type']} (error)",
|
f"{self.metadata['type']} (error)",
|
||||||
f"Failed to parse ledger information: {ledger_str}",
|
f"Failed to parse ledger information: {ledger_str}",
|
||||||
@ -244,10 +253,12 @@ class LedgerOrchestrator(BaseOrchestrator):
|
|||||||
self._chat_history.append(message.content)
|
self._chat_history.append(message.content)
|
||||||
await super()._handle_broadcast(message, ctx)
|
await super()._handle_broadcast(message, ctx)
|
||||||
|
|
||||||
async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
|
async def _select_next_agent(
|
||||||
|
self, message: LLMMessage, cancellation_token: Optional[CancellationToken] = None
|
||||||
|
) -> Optional[AgentProxy]:
|
||||||
# Check if the task is still unset, in which case this message contains the task string
|
# Check if the task is still unset, in which case this message contains the task string
|
||||||
if len(self._task) == 0:
|
if len(self._task) == 0:
|
||||||
await self._initialize_task(self._get_message_str(message))
|
await self._initialize_task(self._get_message_str(message), cancellation_token)
|
||||||
|
|
||||||
# At this point the task, plan and facts shouls all be set
|
# At this point the task, plan and facts shouls all be set
|
||||||
assert len(self._task) > 0
|
assert len(self._task) > 0
|
||||||
@ -263,9 +274,10 @@ class LedgerOrchestrator(BaseOrchestrator):
|
|||||||
await self.publish_message(
|
await self.publish_message(
|
||||||
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])),
|
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])),
|
||||||
topic_id=topic_id,
|
topic_id=topic_id,
|
||||||
|
cancellation_token=cancellation_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
self.logger.info(
|
||||||
OrchestrationEvent(
|
OrchestrationEvent(
|
||||||
f"{self.metadata['type']} (thought)",
|
f"{self.metadata['type']} (thought)",
|
||||||
f"Initial plan:\n{synthesized_prompt}",
|
f"Initial plan:\n{synthesized_prompt}",
|
||||||
@ -279,11 +291,11 @@ class LedgerOrchestrator(BaseOrchestrator):
|
|||||||
self._chat_history.append(synthesized_message)
|
self._chat_history.append(synthesized_message)
|
||||||
|
|
||||||
# Answer from this synthesized message
|
# Answer from this synthesized message
|
||||||
return await self._select_next_agent(synthesized_message)
|
return await self._select_next_agent(synthesized_message, cancellation_token)
|
||||||
|
|
||||||
# Orchestrate the next step
|
# Orchestrate the next step
|
||||||
ledger_dict = await self.update_ledger()
|
ledger_dict = await self.update_ledger(cancellation_token)
|
||||||
logger.info(
|
self.logger.info(
|
||||||
OrchestrationEvent(
|
OrchestrationEvent(
|
||||||
f"{self.metadata['type']} (thought)",
|
f"{self.metadata['type']} (thought)",
|
||||||
f"Updated Ledger:\n{json.dumps(ledger_dict, indent=2)}",
|
f"Updated Ledger:\n{json.dumps(ledger_dict, indent=2)}",
|
||||||
@ -292,7 +304,7 @@ class LedgerOrchestrator(BaseOrchestrator):
|
|||||||
|
|
||||||
# Task is complete
|
# Task is complete
|
||||||
if ledger_dict["is_request_satisfied"]["answer"] is True:
|
if ledger_dict["is_request_satisfied"]["answer"] is True:
|
||||||
logger.info(
|
self.logger.info(
|
||||||
OrchestrationEvent(
|
OrchestrationEvent(
|
||||||
f"{self.metadata['type']} (thought)",
|
f"{self.metadata['type']} (thought)",
|
||||||
"Request satisfied.",
|
"Request satisfied.",
|
||||||
@ -312,7 +324,7 @@ class LedgerOrchestrator(BaseOrchestrator):
|
|||||||
|
|
||||||
# We exceeded our replan counter
|
# We exceeded our replan counter
|
||||||
if self._replan_counter > self._max_replans:
|
if self._replan_counter > self._max_replans:
|
||||||
logger.info(
|
self.logger.info(
|
||||||
OrchestrationEvent(
|
OrchestrationEvent(
|
||||||
f"{self.metadata['type']} (thought)",
|
f"{self.metadata['type']} (thought)",
|
||||||
"Replan counter exceeded... Terminating.",
|
"Replan counter exceeded... Terminating.",
|
||||||
@ -321,7 +333,7 @@ class LedgerOrchestrator(BaseOrchestrator):
|
|||||||
return None
|
return None
|
||||||
# Let's create a new plan
|
# Let's create a new plan
|
||||||
else:
|
else:
|
||||||
logger.info(
|
self.logger.info(
|
||||||
OrchestrationEvent(
|
OrchestrationEvent(
|
||||||
f"{self.metadata['type']} (thought)",
|
f"{self.metadata['type']} (thought)",
|
||||||
"Stalled.... Replanning...",
|
"Stalled.... Replanning...",
|
||||||
@ -329,24 +341,24 @@ class LedgerOrchestrator(BaseOrchestrator):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Update our plan.
|
# Update our plan.
|
||||||
await self._update_facts_and_plan()
|
await self._update_facts_and_plan(cancellation_token)
|
||||||
|
|
||||||
# Reset everyone, then rebroadcast the new plan
|
# Reset everyone, then rebroadcast the new plan
|
||||||
self._chat_history = [self._chat_history[0]]
|
self._chat_history = [self._chat_history[0]]
|
||||||
topic_id = TopicId("default", self.id.key)
|
topic_id = TopicId("default", self.id.key)
|
||||||
await self.publish_message(ResetMessage(), topic_id=topic_id)
|
await self.publish_message(ResetMessage(), topic_id=topic_id, cancellation_token=cancellation_token)
|
||||||
|
|
||||||
# Send everyone the NEW plan
|
# Send everyone the NEW plan
|
||||||
synthesized_prompt = self._get_synthesize_prompt(
|
synthesized_prompt = self._get_synthesize_prompt(
|
||||||
self._task, self._team_description, self._facts, self._plan
|
self._task, self._team_description, self._facts, self._plan
|
||||||
)
|
)
|
||||||
topic_id = TopicId("default", self.id.key)
|
|
||||||
await self.publish_message(
|
await self.publish_message(
|
||||||
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])),
|
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])),
|
||||||
topic_id=topic_id,
|
topic_id=topic_id,
|
||||||
|
cancellation_token=cancellation_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
self.logger.info(
|
||||||
OrchestrationEvent(
|
OrchestrationEvent(
|
||||||
f"{self.metadata['type']} (thought)",
|
f"{self.metadata['type']} (thought)",
|
||||||
f"New plan:\n{synthesized_prompt}",
|
f"New plan:\n{synthesized_prompt}",
|
||||||
@ -357,7 +369,7 @@ class LedgerOrchestrator(BaseOrchestrator):
|
|||||||
self._chat_history.append(synthesized_message)
|
self._chat_history.append(synthesized_message)
|
||||||
|
|
||||||
# Answer from this synthesized message
|
# Answer from this synthesized message
|
||||||
return await self._select_next_agent(synthesized_message)
|
return await self._select_next_agent(synthesized_message, cancellation_token)
|
||||||
|
|
||||||
# If we goit this far, we were not starting, done, or stuck
|
# If we goit this far, we were not starting, done, or stuck
|
||||||
next_agent_name = ledger_dict["next_speaker"]["answer"]
|
next_agent_name = ledger_dict["next_speaker"]["answer"]
|
||||||
@ -367,12 +379,13 @@ class LedgerOrchestrator(BaseOrchestrator):
|
|||||||
instruction = ledger_dict["instruction_or_question"]["answer"]
|
instruction = ledger_dict["instruction_or_question"]["answer"]
|
||||||
user_message = UserMessage(content=instruction, source=self.metadata["type"])
|
user_message = UserMessage(content=instruction, source=self.metadata["type"])
|
||||||
assistant_message = AssistantMessage(content=instruction, source=self.metadata["type"])
|
assistant_message = AssistantMessage(content=instruction, source=self.metadata["type"])
|
||||||
logger.info(OrchestrationEvent(f"{self.metadata['type']} (-> {next_agent_name})", instruction))
|
self.logger.info(OrchestrationEvent(f"{self.metadata['type']} (-> {next_agent_name})", instruction))
|
||||||
self._chat_history.append(assistant_message) # My copy
|
self._chat_history.append(assistant_message) # My copy
|
||||||
topic_id = TopicId("default", self.id.key)
|
topic_id = TopicId("default", self.id.key)
|
||||||
await self.publish_message(
|
await self.publish_message(
|
||||||
BroadcastMessage(content=user_message, request_halt=False),
|
BroadcastMessage(content=user_message, request_halt=False),
|
||||||
topic_id=topic_id,
|
topic_id=topic_id,
|
||||||
|
cancellation_token=cancellation_token,
|
||||||
) # Send to everyone else
|
) # Send to everyone else
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user