2024-06-07 13:33:51 -07:00
""" This is an example of simulating a chess game with two agents
that play against each other, using tools to reason about the game state
and make moves. The agents subscribe to the default topic and publish their
moves to the default topic."""
2024-06-07 13:33:51 -07:00
import argparse
import asyncio
import logging
2025-01-31 14:25:29 -08:00
import yaml
2025-01-07 15:31:29 -08:00
from typing import Annotated , Any , Dict , List , Literal
2024-06-07 13:33:51 -07:00
2024-12-04 16:23:20 -08:00
from autogen_core import (
AgentId ,
AgentRuntime ,
DefaultTopicId ,
2025-01-07 15:31:29 -08:00
MessageContext ,
RoutedAgent ,
2024-12-04 16:23:20 -08:00
SingleThreadedAgentRuntime ,
2025-01-07 15:31:29 -08:00
default_subscription ,
message_handler ,
2024-12-04 16:23:20 -08:00
)
2025-01-07 15:31:29 -08:00
from autogen_core . model_context import BufferedChatCompletionContext , ChatCompletionContext
2025-01-07 16:06:14 -08:00
from autogen_core . models import (
ChatCompletionClient ,
LLMMessage ,
SystemMessage ,
UserMessage ,
)
2025-01-07 15:31:29 -08:00
from autogen_core . tool_agent import ToolAgent , tool_agent_caller_loop
from autogen_core . tools import FunctionTool , Tool , ToolSchema
2024-06-08 01:27:27 -07:00
from chess import BLACK , SQUARE_NAMES , WHITE , Board , Move
2024-06-07 13:33:51 -07:00
from chess import piece_name as get_piece_name
2025-01-07 15:31:29 -08:00
from pydantic import BaseModel
class TextMessage(BaseModel):
    """A plain-text message exchanged between the agents in this example.

    Both fields are required strings (declared as pydantic fields).
    """

    # Who produced the message: an agent type such as "PlayerWhite", or "System"
    # for the kickoff message.
    source: str
    # The message text itself.
    content: str
@default_subscription
class PlayerAgent(RoutedAgent):
    """A model-driven chess player.

    For each incoming ``TextMessage`` it appends the message to its model context,
    runs the tool-agent caller loop (tool calls are executed by the paired tool
    agent), stores every message produced by that loop in the context, and
    publishes the final reply to the default topic.
    """

    def __init__(
        self,
        description: str,
        instructions: str,
        model_client: ChatCompletionClient,
        model_context: ChatCompletionContext,
        tool_schema: List[ToolSchema],
        tool_agent_type: str,
    ) -> None:
        super().__init__(description=description)
        # The system prompt is kept separately from the model context and is
        # prepended to the history on every model call.
        self._system_messages: List[LLMMessage] = [SystemMessage(content=instructions)]
        self._model_client = model_client
        self._tool_schema = tool_schema
        # The tool agent this player uses shares the player's own agent key.
        self._tool_agent_id = AgentId(tool_agent_type, self.id.key)
        self._model_context = model_context

    @message_handler
    async def handle_message(self, message: TextMessage, ctx: MessageContext) -> None:
        """Answer ``message`` by running one tool-calling turn and publishing the reply."""
        incoming = UserMessage(content=message.content, source=message.source)
        await self._model_context.add_message(incoming)

        history = await self._model_context.get_messages()
        prompt = self._system_messages + history
        # Let the model call tools through the tool agent; the messages returned
        # here are recorded in the context, and the last one is the reply.
        generated = await tool_agent_caller_loop(
            self,
            tool_agent_id=self._tool_agent_id,
            model_client=self._model_client,
            input_messages=prompt,
            tool_schema=self._tool_schema,
            cancellation_token=ctx.cancellation_token,
        )

        for produced in generated:
            await self._model_context.add_message(produced)

        reply = generated[-1]
        assert isinstance(reply.content, str)
        outgoing = TextMessage(content=reply.content, source=self.id.type)
        await self.publish_message(outgoing, DefaultTopicId())
2024-06-25 13:23:29 -07:00
2024-06-07 13:33:51 -07:00
2024-06-08 01:27:27 -07:00
def validate_turn(board: Board, player: Literal["white", "black"]) -> None:
    """Validate that it is the player's turn to move.

    The side to move is read from ``board.turn``, which python-chess maintains
    itself. This avoids inferring it from the colour of the piece on the last
    move's destination square, which breaks for positions that did not start
    from the initial setup and after null moves.

    Args:
        board: The shared game board.
        player: The side attempting to act, ``"white"`` or ``"black"``.

    Raises:
        ValueError: If it is not ``player``'s turn to move.
    """
    if player == "white" and board.turn != WHITE:
        raise ValueError("It is not your turn to move. Wait for black to move.")
    if player == "black" and board.turn != BLACK:
        if not board.move_stack:
            # Nothing has been played on this board yet.
            raise ValueError("It is not your turn to move. Wait for white to move first.")
        raise ValueError("It is not your turn to move. Wait for white to move.")
def get_legal_moves(
    board: Board, player: Literal["white", "black"]
) -> Annotated[str, "A list of legal moves in UCI format."]:
    """Describe the legal moves available to ``player`` as a single sentence.

    Raises:
        ValueError: If it is not ``player``'s turn (via ``validate_turn``) or
            ``player`` is neither ``"black"`` nor ``"white"``.
    """
    validate_turn(board, player)
    side = BLACK if player == "black" else (WHITE if player == "white" else None)
    if side is None:
        raise ValueError("Invalid player, must be either 'black' or 'white'.")

    # Keep only moves that start from a square holding one of the player's pieces.
    ucis = [
        candidate.uci()
        for candidate in board.legal_moves
        if board.color_at(candidate.from_square) == side
    ]
    if not ucis:
        return "No legal moves. The game is over."
    return "Possible moves are: " + ", ".join(ucis)
def get_board(board: Board) -> str:
    """Return the text rendering of the current position, i.e. ``str(board)``."""
    rendered = str(board)
    return rendered
def make_move(
    board: Board,
    player: Literal["white", "black"],
    thinking: Annotated[str, "Thinking for the move."],
    move: Annotated[str, "A move in UCI format."],
) -> Annotated[str, "Result of the move."]:
    """Make a move on the board.

    Checks the turn, validates the move against the current position, applies
    it, prints a summary to stdout, and returns a short description of the move.

    Args:
        board: The shared game board; mutated in place.
        player: The side making the move.
        thinking: The model's reasoning, printed alongside the move.
        move: The move in UCI notation, e.g. ``"e2e4"``.

    Returns:
        A sentence such as ``"Moved Pawn (...) from e2 to e4."``.

    Raises:
        ValueError: If it is not ``player``'s turn, or ``move`` is malformed,
            a null move, or illegal in the current position.
    """
    validate_turn(board, player)
    # Board.push() does not check legality, so parse through the board instead:
    # parse_uci() raises ValueError for malformed or illegal moves.
    try:
        new_move = board.parse_uci(move)
    except ValueError as e:
        raise ValueError(
            f"Invalid or illegal move {move!r}: {e}. Call get_legal_moves to see the legal moves."
        ) from e
    if not new_move:
        # parse_uci() accepts the null move "0000"; passing the turn is not allowed here.
        raise ValueError(f"Invalid or illegal move {move!r}: null moves are not allowed.")
    board.push(new_move)

    # Print the move.
    print("-" * 50)
    print("Player:", player)
    print("Move:", new_move.uci())
    print("Thinking:", thinking)
    print("Board:")
    print(board.unicode(borders=True))

    # Describe the piece now standing on the destination square (the promoted
    # piece, for promotions).
    piece = board.piece_at(new_move.to_square)
    # A legal move always leaves the moved piece on its destination square.
    assert piece is not None
    piece_symbol = piece.unicode_symbol()
    piece_name = get_piece_name(piece.piece_type)
    # Unicode chess glyphs are uncased, so str.isupper() cannot detect white
    # pieces; use the piece colour to capitalize white piece names.
    if piece.color == WHITE:
        piece_name = piece_name.capitalize()
    return (
        f"Moved {piece_name} ({piece_symbol}) from {SQUARE_NAMES[new_move.from_square]} "
        f"to {SQUARE_NAMES[new_move.to_square]}."
    )
2024-06-08 01:27:27 -07:00
2025-01-07 15:31:29 -08:00
async def chess_game(runtime: AgentRuntime, model_config: Dict[str, Any]) -> None:  # type: ignore
    """Register the agents for a chess game on ``runtime``.

    A single ``Board`` is created and exposed to each player through tools that
    are bound to that player's colour. Four agent types are registered:

    * ``PlayerBlackToolAgent`` / ``PlayerWhiteToolAgent``: ``ToolAgent``s
      holding the black / white tools.
    * ``PlayerBlack`` / ``PlayerWhite``: ``PlayerAgent``s using a model client
      loaded from ``model_config``, each paired with its colour's tool agent.

    Nothing is returned and no message is sent; the caller starts the game by
    messaging a player.

    Args:
        runtime: Runtime to register the agent types with.
        model_config: Component configuration passed to
            ``ChatCompletionClient.load_component``.
    """
    from typing import Callable

    # Create the board shared by both players' tools.
    board = Board()

    # Create tools for each player. Each closure captures ``board`` and fixes the
    # colour, so a player can only query or move for its own side.
    def get_legal_moves_black() -> str:
        return get_legal_moves(board, "black")

    def get_legal_moves_white() -> str:
        return get_legal_moves(board, "white")

    def make_move_black(
        thinking: Annotated[str, "Thinking for the move"],
        move: Annotated[str, "A move in UCI format"],
    ) -> str:
        return make_move(board, "black", thinking, move)

    def make_move_white(
        thinking: Annotated[str, "Thinking for the move"],
        move: Annotated[str, "A move in UCI format"],
    ) -> str:
        return make_move(board, "white", thinking, move)

    def get_board_text() -> Annotated[str, "The current board state"]:
        return get_board(board)

    def build_tools(legal_moves_fn: Callable[[], str], make_move_fn: Callable[..., str]) -> List[Tool]:
        """Build the get_legal_moves / make_move / get_board tool set for one colour."""
        return [
            FunctionTool(legal_moves_fn, name="get_legal_moves", description="Get legal moves."),
            FunctionTool(make_move_fn, name="make_move", description="Make a move."),
            FunctionTool(get_board_text, name="get_board", description="Get the current board state."),
        ]

    def player_instructions(color: str) -> str:
        """System prompt for the player of ``color`` ("black" or "white")."""
        return (
            f"You are a chess player and you play as {color}. Use the tool 'get_board' and "
            "'get_legal_moves' to get the legal moves and 'make_move' to make a move."
        )

    black_tools: List[Tool] = build_tools(get_legal_moves_black, make_move_black)
    white_tools: List[Tool] = build_tools(get_legal_moves_white, make_move_white)

    # One model client instance is shared by both players.
    model_client = ChatCompletionClient.load_component(model_config)

    # Register the tool agents that execute each side's tools.
    await ToolAgent.register(
        runtime,
        "PlayerBlackToolAgent",
        lambda: ToolAgent(description="Tool agent for chess game.", tools=black_tools),
    )
    await ToolAgent.register(
        runtime,
        "PlayerWhiteToolAgent",
        lambda: ToolAgent(description="Tool agent for chess game.", tools=white_tools),
    )

    # Register the players; each gets its own buffered model context.
    await PlayerAgent.register(
        runtime,
        "PlayerBlack",
        lambda: PlayerAgent(
            description="Player playing black.",
            instructions=player_instructions("black"),
            model_client=model_client,
            model_context=BufferedChatCompletionContext(buffer_size=10),
            tool_schema=[tool.schema for tool in black_tools],
            tool_agent_type="PlayerBlackToolAgent",
        ),
    )

    await PlayerAgent.register(
        runtime,
        "PlayerWhite",
        lambda: PlayerAgent(
            description="Player playing white.",
            instructions=player_instructions("white"),
            model_client=model_client,
            model_context=BufferedChatCompletionContext(buffer_size=10),
            tool_schema=[tool.schema for tool in white_tools],
            tool_agent_type="PlayerWhiteToolAgent",
        ),
    )
2025-01-07 15:31:29 -08:00
async def main(model_config: Dict[str, Any]) -> None:
    """Run one game.

    Registers the agents, starts the runtime, prompts the white player for the
    first move, and waits for the runtime to go idle.
    """
    runtime = SingleThreadedAgentRuntime()
    await chess_game(runtime, model_config)
    runtime.start()

    # Kick off the game by messaging the white player directly. From then on the
    # players drive the game by publishing their replies to the default topic.
    opening = TextMessage(content="Game started, white player your move.", source="System")
    white_player = AgentId("PlayerWhite", "default")
    await runtime.send_message(opening, white_player)

    await runtime.stop_when_idle()
2024-06-07 13:33:51 -07:00
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a chess game between two agents.")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose logging.")
    parser.add_argument(
        "--model-config", type=str, help="Path to the model configuration file.", default="model_config.yml"
    )
    args = parser.parse_args()
    if args.verbose:
        # Root logger at WARNING. autogen_core logs at DEBUG; its records are written
        # to chess_game.log and, through propagation, also reach the root console
        # handler installed by basicConfig.
        logging.basicConfig(level=logging.WARNING)
        logging.getLogger("autogen_core").setLevel(logging.DEBUG)
        handler = logging.FileHandler("chess_game.log")
        logging.getLogger("autogen_core").addHandler(handler)

    # Decode the config explicitly as UTF-8 rather than with the platform's
    # locale-dependent default encoding.
    with open(args.model_config, "r", encoding="utf-8") as f:
        model_config = yaml.safe_load(f)
    asyncio.run(main(model_config))