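"""Run multiple MagenticOne teams on the same task concurrently and aggregate their answers.

Each team (coder, code executor, file surfer, web surfer) works on the task independently
with its own console log and JSONL event log. Once enough teams have finished, the
remaining runs are cancelled and a single aggregator model call combines the per-team
transcripts and stop reasons into one FINAL ANSWER.
"""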

import asyncio
import os
import re
import logging
import yaml
import warnings
import contextvars
import builtins
import shutil
import json
from datetime import datetime
from typing import List, Optional, Dict
from collections import deque
from autogen_agentchat import TRACE_LOGGER_NAME as AGENTCHAT_TRACE_LOGGER_NAME, EVENT_LOGGER_NAME as AGENTCHAT_EVENT_LOGGER_NAME
from autogen_core import TRACE_LOGGER_NAME as CORE_TRACE_LOGGER_NAME, EVENT_LOGGER_NAME as CORE_EVENT_LOGGER_NAME
from autogen_ext.agents.magentic_one import MagenticOneCoderAgent
from autogen_agentchat.teams import MagenticOneGroupChat
from autogen_agentchat.ui import Console
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
LLMMessage,
UserMessage,
)
from autogen_core.logging import LLMCallEvent
from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
from autogen_agentchat.conditions import TextMentionTermination
from autogen_ext.agents.web_surfer import MultimodalWebSurfer
from autogen_ext.agents.file_surfer import FileSurfer
from autogen_agentchat.agents import CodeExecutorAgent
from autogen_agentchat.messages import (
TextMessage,
AgentEvent,
ChatMessage,
HandoffMessage,
MultiModalMessage,
StopMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from autogen_core import CancellationToken
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.models.openai._model_info import _MODEL_TOKEN_LIMITS, resolve_model
from autogen_agentchat.utils import content_to_str
# Suppress warnings about the requests.Session() not being closed
warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning)
core_event_logger = logging.getLogger(CORE_EVENT_LOGGER_NAME)
agentchat_event_logger = logging.getLogger(AGENTCHAT_EVENT_LOGGER_NAME)
agentchat_trace_logger = logging.getLogger(AGENTCHAT_TRACE_LOGGER_NAME)
# Defined at module level because aggregate_final_answer() logs through it.
aggregator_logger = logging.getLogger("aggregator")
# Context variables holding the current team's console log file and the current team id.
current_log_file = contextvars.ContextVar("current_log_file", default=None)
current_team_id = contextvars.ContextVar("current_team_id", default=None)
# Save the original print function and event-logger .info methods so they can be used as
# fallbacks when no team-specific context is active.
original_print = builtins.print
original_agentchat_event_logger_info = agentchat_event_logger.info
original_core_event_logger_info = core_event_logger.info
class LogHandler(logging.FileHandler):
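    """FileHandler that serializes selected events as JSON lines.

    Agent chat messages are written with their timestamp, source, text content, and type;
    LLMCallEvent records are written with their prompt/completion token counts. All other
    records are dropped.
    """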
def __init__(self, filename: str = "log.jsonl", print_message: bool = True) -> None:
super().__init__(filename, mode="w")
self.print_message = print_message
def emit(self, record: logging.LogRecord) -> None:
try:
ts = datetime.fromtimestamp(record.created).isoformat()
if AGENTCHAT_EVENT_LOGGER_NAME in record.name:
original_msg = record.msg
record.msg = json.dumps(
{
"timestamp": ts,
"source": record.msg.source,
"message": content_to_str(record.msg.content),
"type": record.msg.type,
}
)
super().emit(record)
record.msg = original_msg
elif CORE_EVENT_LOGGER_NAME in record.name:
if isinstance(record.msg, LLMCallEvent):
original_msg = record.msg
record.msg = json.dumps(
{
"timestamp": ts,
"prompt_tokens": record.msg.kwargs["prompt_tokens"],
"completion_tokens": record.msg.kwargs["completion_tokens"],
"type": "LLMCallEvent",
}
)
super().emit(record)
record.msg = original_msg
except Exception:
print("error in logHandler.emit", flush=True)
self.handleError(record)
def tee_print(*args, **kwargs):
# Get the current log file from the context.
log_file = current_log_file.get()
# Call the original print (goes to the console).
original_print(*args, **kwargs)
# Also write to the log file if one is set.
if log_file is not None:
sep = kwargs.get("sep", " ")
end = kwargs.get("end", "\n")
message = sep.join(map(str, args)) + end
log_file.write(message)
log_file.flush()
def team_specific_agentchat_event_logger_info(msg, *args, **kwargs):
team_id = current_team_id.get()
if team_id is not None:
# Get a logger with a team-specific name.
team_logger = logging.getLogger(f"{AGENTCHAT_EVENT_LOGGER_NAME}.team{team_id}")
team_logger.info(msg, *args, **kwargs)
else:
original_agentchat_event_logger_info(msg, *args, **kwargs)
def team_specific_core_event_logger_info(msg, *args, **kwargs):
team_id = current_team_id.get()
if team_id is not None:
# Get a logger with a team-specific name.
team_logger = logging.getLogger(f"{CORE_EVENT_LOGGER_NAME}.team{team_id}")
team_logger.info(msg, *args, **kwargs)
else:
original_core_event_logger_info(msg, *args, **kwargs)
# Monkey-patch the built-in print and event_logger.info methods with our team-specific versions.
builtins.print = tee_print
agentchat_event_logger.info = team_specific_agentchat_event_logger_info
core_event_logger.info = team_specific_core_event_logger_info
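# Note: the per-team child loggers used above propagate to their parents by default, so
# team-specific records also reach any handlers attached to the root event loggers.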
async def run_team(team: MagenticOneGroupChat, team_idx: int, task: str, cancellation_token: CancellationToken, logfile):
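    """Run a single team to completion and return (team_idx, TaskResult).

    Binds the console log file and team id context variables for the duration of the run
    so that the patched print/logging functions route output to this team's files.
    """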
token_logfile = current_log_file.set(logfile)
token_team_id = current_team_id.set(team_idx)
try:
task_result = await Console(
team.run_stream(
task=task.strip(),
cancellation_token=cancellation_token
)
)
return team_idx, task_result
finally:
current_log_file.reset(token_logfile)
current_team_id.reset(token_team_id)
logfile.close()
async def aggregate_final_answer(task: str, client: ChatCompletionClient, team_results, source: str = "Aggregator", cancellation_token: Optional[CancellationToken] = None) -> str:
"""
team_results: {"team_key": TaskResult}
team_completion_order: The order in which the teams completed their tasks
"""
if len(team_results) == 1:
final_answer = list(team_results.values())[0].messages[-1].content
aggregator_logger.info(
f"{source} (Response):\n{final_answer}"
)
return final_answer
assert len(team_results) > 1
aggregator_messages_to_send = {team_id: deque() for team_id in team_results.keys()} # {team_id: context}
team_ids = list(team_results.keys())
current_round = 0
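    # Walk each team's transcript backwards in round-robin order, prepending messages to
    # that team's deque until every transcript is consumed or (when the model's token
    # limit is known) fewer than ~2000 tokens of budget remain.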
    while (
        any(len(team_result.messages) > 0 for team_result in team_results.values())
        and (
            resolve_model(client._create_args["model"]) not in _MODEL_TOKEN_LIMITS
            or client.remaining_tokens(
                [m for messages in aggregator_messages_to_send.values() for m in messages]
            )
            > 2000
        )
    ):
team_idx = team_ids[current_round % len(team_ids)]
if len(team_results[team_idx].messages) > 0:
m = team_results[team_idx].messages[-1]
if isinstance(m, ToolCallRequestEvent | ToolCallExecutionEvent):
# Ignore tool call messages.
pass
elif isinstance(m, StopMessage | HandoffMessage):
aggregator_messages_to_send[team_idx].appendleft(UserMessage(content=m.to_model_text(), source=m.source))
elif m.source == "MagenticOneOrchestrator":
assert isinstance(m, TextMessage | ToolCallSummaryMessage)
aggregator_messages_to_send[team_idx].appendleft(AssistantMessage(content=m.to_model_text(), source=m.source))
else:
assert isinstance(m, (TextMessage, MultiModalMessage, ToolCallSummaryMessage))
aggregator_messages_to_send[team_idx].appendleft(UserMessage(content=m.to_model_text(), source=m.source))
team_results[team_idx].messages.pop()
current_round += 1
# Log the messages to send
payload = ""
for team_idx, messages in aggregator_messages_to_send.items():
        payload += f"\n{'*'*75} \n" f"Team #{team_idx}" f"\n{'*'*75} \n"
for message in messages:
payload += f"\n{'-'*75} \n" f"{message.source}:\n" f"\n{message.content}\n"
payload += f"\n{'-'*75} \n" f"Team #{team_idx} stop reason:\n" f"\n{team_results[team_idx].stop_reason}\n"
payload += f"\n{'*'*75} \n"
aggregator_logger.info(f"{source} (Aggregator Messages):\n{payload}")
context: List[LLMMessage] = []
# Add the preamble
context.append(
UserMessage(
content=f"Earlier you were asked the following:\n\n{task}\n\nYour team then worked diligently to address that request. You have been provided with a collection of transcripts and stop reasons from {len(team_results)} different teams to the question. Your task is to carefully evaluate the correctness of each team's response by analyzing their respective transcripts and stop reasons. After considering all perspectives, provide a FINAL ANSWER to the question. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect.",
source=source,
)
)
for team_idx, aggregator_messages in aggregator_messages_to_send.items():
context.append(
UserMessage(
content=f"Transcript from Team #{team_idx}:",
source=source,
)
)
for message in aggregator_messages:
context.append(message)
context.append(
UserMessage(
content=f"Stop reason from Team #{team_idx}:",
source=source,
)
)
context.append(
UserMessage(
content=team_results[team_idx].stop_reason if team_results[team_idx].stop_reason else "No stop reason provided.",
source=source,
)
)
# ask for the final answer
context.append(
UserMessage(
content=f"""
Let's think step-by-step. Carefully review the conversation above, critically evaluate the correctness of each team's response, and then output a FINAL ANSWER to the question. The question is repeated here for convenience:
{task}
To output the final answer, use the following template: FINAL ANSWER: [YOUR FINAL ANSWER]
Your FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
ADDITIONALLY, your FINAL ANSWER MUST adhere to any formatting instructions specified in the original question (e.g., alphabetization, sequencing, units, rounding, decimal places, etc.)
If you are asked for a number, express it numerically (i.e., with digits rather than words), don't use commas, and don't include units such as $ or percent signs unless specified otherwise.
If you are asked for a string, don't use articles or abbreviations (e.g. for cities), unless specified otherwise. Don't output any final sentence punctuation such as '.', '!', or '?'.
If you are asked for a comma separated list, apply the above rules depending on whether the elements are numbers or strings.
""".strip(),
source=source,
)
)
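    # A single completion call over the combined transcripts produces the aggregated answer.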
response = await client.create(context, cancellation_token=cancellation_token)
assert isinstance(response.content, str)
final_answer = re.sub(r"FINAL ANSWER:", "[FINAL ANSWER]:", response.content)
aggregator_logger.info(
f"{source} (Response):\n{final_answer}"
)
return re.sub(r"FINAL ANSWER:", "FINAL AGGREGATED ANSWER:", response.content)
async def main(num_teams: int, num_answers: int) -> None:
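    """Build and run `num_teams` MagenticOne teams concurrently on the task from prompt.txt,
    wait for the first `num_answers` of them to finish, cancel the rest, and print the
    aggregated final answer."""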
# Load model configuration and create the model client.
with open("config.yaml", "r") as f:
config = yaml.safe_load(f)
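    # config.yaml is expected to provide one model-client component config per key below
    # (orchestrator_client, coder_client, web_surfer_client, file_surfer_client); see the
    # illustrative sketch at the bottom of this file.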
orchestrator_client = ChatCompletionClient.load_component(config["orchestrator_client"])
coder_client = ChatCompletionClient.load_component(config["coder_client"])
web_surfer_client = ChatCompletionClient.load_component(config["web_surfer_client"])
file_surfer_client = ChatCompletionClient.load_component(config["file_surfer_client"])
# Read the prompt
prompt = ""
with open("prompt.txt", "rt") as fh:
prompt = fh.read().strip()
filename = "__FILE_NAME__".strip()
# Prepare the prompt
filename_prompt = ""
if len(filename) > 0:
filename_prompt = f"The question is about a file, document or image, which can be accessed by the filename '{filename}' in the current working directory."
task = f"{prompt}\n\n{filename_prompt}"
    # Reset the logs directory: remove any existing contents, then recreate it empty so
    # the web surfer's debug output has somewhere to go.
    logs_dir = "logs"
    if os.path.exists(logs_dir):
        shutil.rmtree(logs_dir)
    os.makedirs(logs_dir, exist_ok=True)
teams = []
async_tasks = []
tokens = []
for team_idx in range(num_teams):
# Set up the team
        coder = MagenticOneCoderAgent(
            "Assistant",
            model_client=coder_client,
        )
        executor = CodeExecutorAgent("ComputerTerminal", code_executor=LocalCommandLineCodeExecutor())
        file_surfer = FileSurfer(
            name="FileSurfer",
            model_client=file_surfer_client,
        )
        web_surfer = MultimodalWebSurfer(
            name="WebSurfer",
            model_client=web_surfer_client,
            downloads_folder=os.getcwd(),
            debug_dir=logs_dir,
            to_save_screenshots=True,
        )
team = MagenticOneGroupChat(
[coder, executor, file_surfer, web_surfer],
model_client=orchestrator_client,
max_turns=30,
            final_answer_prompt=f"""
We have completed the following task:
{prompt}
The above messages contain the conversation that took place to complete the task.
Read the above conversation and output a FINAL ANSWER to the question.
To output the final answer, use the following template: FINAL ANSWER: [YOUR FINAL ANSWER]
Your FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
ADDITIONALLY, your FINAL ANSWER MUST adhere to any formatting instructions specified in the original question (e.g., alphabetization, sequencing, units, rounding, decimal places, etc.)
If you are asked for a number, express it numerically (i.e., with digits rather than words), don't use commas, and don't include units such as $ or percent signs unless specified otherwise.
If you are asked for a string, don't use articles or abbreviations (e.g. for cities), unless specified otherwise. Don't output any final sentence punctuation such as '.', '!', or '?'.
If you are asked for a comma separated list, apply the above rules depending on whether the elements are numbers or strings.
""".strip()
)
teams.append(team)
cancellation_token = CancellationToken()
tokens.append(cancellation_token)
logfile = open(f"console_log_{team_idx}.txt", "w")
team_agentchat_logger = logging.getLogger(f"{AGENTCHAT_EVENT_LOGGER_NAME}.team{team_idx}")
team_core_logger = logging.getLogger(f"{CORE_EVENT_LOGGER_NAME}.team{team_idx}")
team_log_handler = LogHandler(f"log_{team_idx}.jsonl", print_message=False)
team_agentchat_logger.addHandler(team_log_handler)
team_core_logger.addHandler(team_log_handler)
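        # Because the event loggers' .info methods are patched earlier in this module,
        # records emitted while this team is the active context are routed to these child
        # loggers and land in log_<team_idx>.jsonl.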
async_task = asyncio.create_task(
run_team(team, team_idx, task, cancellation_token, logfile)
)
async_tasks.append(async_task)
# Wait until at least num_answers tasks have completed.
team_results = {}
for future in asyncio.as_completed(async_tasks):
try:
team_id, result = await future
team_results[team_id] = result
except Exception as e:
# Optionally log exception.
print(f"Task raised an exception: {e}")
if len(team_results) >= num_answers:
break
# Cancel any pending teams.
    for async_task, token in zip(async_tasks, tokens):
        if not async_task.done():
            token.cancel()
# Await all tasks to handle cancellation gracefully.
await asyncio.gather(*async_tasks, return_exceptions=True)
print("len(team_results):", len(team_results))
final_answer = await aggregate_final_answer(prompt, orchestrator_client, team_results)
print(final_answer)
if __name__ == "__main__":
num_teams = 3
num_answers = 3
agentchat_trace_logger.setLevel(logging.DEBUG)
file_handler = logging.FileHandler("trace.log", mode="w")
file_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
file_handler.setFormatter(formatter)
agentchat_trace_logger.addHandler(file_handler)
core_event_logger.setLevel(logging.DEBUG)
agentchat_event_logger.setLevel(logging.DEBUG)
log_handler = LogHandler()
core_event_logger.addHandler(log_handler)
agentchat_event_logger.addHandler(log_handler)
# Create another logger for the aggregator
aggregator_logger = logging.getLogger("aggregator")
aggregator_logger.setLevel(logging.DEBUG)
fh = logging.FileHandler("aggregator_log.txt", mode="w")
fh.setLevel(logging.DEBUG)
aggregator_logger.addHandler(fh)
asyncio.run(main(num_teams, num_answers))
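
# Illustrative config.yaml sketch (an assumption, not taken from this repository; field
# names follow AutoGen's component-config format and the model name is a placeholder):
#
#   orchestrator_client:
#     provider: autogen_ext.models.openai.OpenAIChatCompletionClient
#     config:
#       model: gpt-4o
#   coder_client: ...
#   web_surfer_client: ...
#   file_surfer_client: ...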