mirror of https://github.com/microsoft/autogen.git
Add AssistantAgent, deprecate CodingAssistantAgent and ToolUseAssistantAgent (#3960)
* Add AssistantAgent, deprecate CodingAssistantAgent and ToolUseAssistantAgent
* Rename
* Add note
* Update uv
* uv lock
* Merge branch 'main' into assistant-agent
* Update uv
This commit is contained in:
parent 69fc742537
commit 3fe0f9e97d
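For orientation, a minimal usage sketch of the new unified agent, pieced together from the test code in this diff (the model name and the weather tool are illustrative, not part of the commit):

import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_ext.models import OpenAIChatCompletionClient


async def get_weather(city: str) -> str:
    """Look up the weather for a city (stub for illustration)."""
    return f"It is always sunny in {city}."


async def main() -> None:
    agent = AssistantAgent(
        "assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-05-13"),
        tools=[get_weather],  # plain callables are wrapped in FunctionTool automatically
    )
    result = await agent.run("What is the weather in Seattle?")
    print(result.messages[-1].content)


asyncio.run(main())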
@@ -1,3 +1,4 @@
+from ._assistant_agent import AssistantAgent
 from ._base_chat_agent import BaseChatAgent
 from ._code_executor_agent import CodeExecutorAgent
 from ._coding_assistant_agent import CodingAssistantAgent
@@ -5,6 +6,7 @@ from ._tool_use_assistant_agent import ToolUseAssistantAgent
 
 __all__ = [
     "BaseChatAgent",
+    "AssistantAgent",
     "CodeExecutorAgent",
     "CodingAssistantAgent",
     "ToolUseAssistantAgent",
@@ -0,0 +1,140 @@
+import asyncio
+import json
+import logging
+from typing import Any, Awaitable, Callable, List, Sequence
+
+from autogen_core.base import CancellationToken
+from autogen_core.components import FunctionCall
+from autogen_core.components.models import (
+    AssistantMessage,
+    ChatCompletionClient,
+    FunctionExecutionResult,
+    FunctionExecutionResultMessage,
+    LLMMessage,
+    SystemMessage,
+    UserMessage,
+)
+from autogen_core.components.tools import FunctionTool, Tool
+from pydantic import BaseModel, ConfigDict
+
+from .. import EVENT_LOGGER_NAME
+from ..messages import (
+    ChatMessage,
+    StopMessage,
+    TextMessage,
+)
+from ._base_chat_agent import BaseChatAgent
+
+event_logger = logging.getLogger(EVENT_LOGGER_NAME)
+
+
+class ToolCallEvent(BaseModel):
+    """A tool call event."""
+
+    tool_calls: List[FunctionCall]
+    """The tool call message."""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
+class ToolCallResultEvent(BaseModel):
+    """A tool call result event."""
+
+    tool_call_results: List[FunctionExecutionResult]
+    """The tool call result message."""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
+class AssistantAgent(BaseChatAgent):
+    """An agent that provides assistance with tool use.
+
+    It responds with a StopMessage when 'terminate' is detected in the response.
+
+    Args:
+        name (str): The name of the agent.
+        model_client (ChatCompletionClient): The model client to use for inference.
+        tools (List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional): The tools to register with the agent.
+        description (str, optional): The description of the agent.
+        system_message (str, optional): The system message for the model.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        model_client: ChatCompletionClient,
+        *,
+        tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
+        description: str = "An agent that provides assistance with ability to use tools.",
+        system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed.",
+    ):
+        super().__init__(name=name, description=description)
+        self._model_client = model_client
+        self._system_messages = [SystemMessage(content=system_message)]
+        self._tools: List[Tool] = []
+        if tools is not None:
+            for tool in tools:
+                if isinstance(tool, Tool):
+                    self._tools.append(tool)
+                elif callable(tool):
+                    if hasattr(tool, "__doc__") and tool.__doc__ is not None:
+                        description = tool.__doc__
+                    else:
+                        description = ""
+                    self._tools.append(FunctionTool(tool, description=description))
+                else:
+                    raise ValueError(f"Unsupported tool type: {type(tool)}")
+        self._model_context: List[LLMMessage] = []
+
+    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
+        # Add messages to the model context.
+        for msg in messages:
+            # TODO: add special handling for handoff messages
+            self._model_context.append(UserMessage(content=msg.content, source=msg.source))
+
+        # Generate an inference result based on the current model context.
+        llm_messages = self._system_messages + self._model_context
+        result = await self._model_client.create(llm_messages, tools=self._tools, cancellation_token=cancellation_token)
+
+        # Add the response to the model context.
+        self._model_context.append(AssistantMessage(content=result.content, source=self.name))
+
+        # Run tool calls until the model produces a string response.
+        while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
+            event_logger.debug(ToolCallEvent(tool_calls=result.content))
+            # Execute the tool calls.
+            results = await asyncio.gather(
+                *[self._execute_tool_call(call, cancellation_token) for call in result.content]
+            )
+            event_logger.debug(ToolCallResultEvent(tool_call_results=results))
+            self._model_context.append(FunctionExecutionResultMessage(content=results))
+            # Generate an inference result based on the current model context.
+            result = await self._model_client.create(
+                self._model_context, tools=self._tools, cancellation_token=cancellation_token
+            )
+            self._model_context.append(AssistantMessage(content=result.content, source=self.name))
+
+        assert isinstance(result.content, str)
+        # Detect stop request.
+        request_stop = "terminate" in result.content.strip().lower()
+        if request_stop:
+            return StopMessage(content=result.content, source=self.name)
+
+        return TextMessage(content=result.content, source=self.name)
+
+    async def _execute_tool_call(
+        self, tool_call: FunctionCall, cancellation_token: CancellationToken
+    ) -> FunctionExecutionResult:
+        """Execute a tool call and return the result."""
+        try:
+            if not self._tools:
+                raise ValueError("No tools are available.")
+            tool = next((t for t in self._tools if t.name == tool_call.name), None)
+            if tool is None:
+                raise ValueError(f"The tool '{tool_call.name}' is not available.")
+            arguments = json.loads(tool_call.arguments)
+            result = await tool.run_json(arguments, cancellation_token)
+            result_as_str = tool.return_value_as_string(result)
+            return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id)
+        except Exception as e:
+            return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id)
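The _execute_tool_call path above delegates to the Tool protocol. A small standalone sketch of what run_json and return_value_as_string do for a wrapped callable (the add function is hypothetical, not part of the commit):

import asyncio

from autogen_core.base import CancellationToken
from autogen_core.components.tools import FunctionTool


async def add(x: int, y: int) -> int:
    """Add two integers."""
    return x + y


async def main() -> None:
    tool = FunctionTool(add, description="Add two integers")
    # run_json receives the JSON-decoded arguments, mirroring _execute_tool_call above.
    result = await tool.run_json({"x": 2, "y": 3}, CancellationToken())
    print(tool.return_value_as_string(result))  # "5"


asyncio.run(main())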
@@ -1,20 +1,14 @@
-from typing import List, Sequence
+import warnings
 
-from autogen_core.base import CancellationToken
 from autogen_core.components.models import (
-    AssistantMessage,
     ChatCompletionClient,
-    LLMMessage,
-    SystemMessage,
-    UserMessage,
 )
 
-from ..messages import ChatMessage, MultiModalMessage, StopMessage, TextMessage
-from ._base_chat_agent import BaseChatAgent
+from ._assistant_agent import AssistantAgent
 
 
-class CodingAssistantAgent(BaseChatAgent):
-    """An agent that provides coding assistance using an LLM model client.
+class CodingAssistantAgent(AssistantAgent):
+    """[DEPRECATED] An agent that provides coding assistance using an LLM model client.
 
     It responds with a StopMessage when 'terminate' is detected in the response.
     """
@@ -37,29 +31,10 @@ If the result indicates there is an error, fix the error and output the code aga
 When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
 Reply "TERMINATE" in the end when code has been executed and task is complete.""",
     ):
-        super().__init__(name=name, description=description)
-        self._model_client = model_client
-        self._system_messages = [SystemMessage(content=system_message)]
-        self._model_context: List[LLMMessage] = []
-
-    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
-        # Add messages to the model context and detect stopping.
-        for msg in messages:
-            if not isinstance(msg, TextMessage | MultiModalMessage | StopMessage):
-                raise ValueError(f"Unsupported message type: {type(msg)}")
-            self._model_context.append(UserMessage(content=msg.content, source=msg.source))
-
-        # Generate an inference result based on the current model context.
-        llm_messages = self._system_messages + self._model_context
-        result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
-        assert isinstance(result.content, str)
-
-        # Add the response to the model context.
-        self._model_context.append(AssistantMessage(content=result.content, source=self.name))
-
-        # Detect stop request.
-        request_stop = "terminate" in result.content.strip().lower()
-        if request_stop:
-            return StopMessage(content=result.content, source=self.name)
-
-        return TextMessage(content=result.content, source=self.name)
+        # Deprecation warning.
+        warnings.warn(
+            "CodingAssistantAgent is deprecated. Use AssistantAgent instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(name, model_client, description=description, system_message=system_message)
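Because the subclass now only emits a warning and forwards to AssistantAgent, the deprecation is observable at construction time. A quick check, as a sketch (it reuses the empty-api-key client construction from the tests in this diff):

import warnings

from autogen_agentchat.agents import CodingAssistantAgent
from autogen_ext.models import OpenAIChatCompletionClient

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    CodingAssistantAgent(
        "coding_assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-05-13", api_key=""),
    )
# The constructor should have raised exactly the DeprecationWarning added above.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)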
@@ -1,53 +1,20 @@
-import asyncio
-import json
 import logging
-from typing import Any, Awaitable, Callable, List, Sequence
+import warnings
+from typing import Any, Awaitable, Callable, List
 
-from autogen_core.base import CancellationToken
-from autogen_core.components import FunctionCall
 from autogen_core.components.models import (
-    AssistantMessage,
     ChatCompletionClient,
-    FunctionExecutionResult,
-    FunctionExecutionResultMessage,
-    LLMMessage,
-    SystemMessage,
-    UserMessage,
 )
-from autogen_core.components.tools import FunctionTool, Tool
-from pydantic import BaseModel, ConfigDict
+from autogen_core.components.tools import Tool
 
 from .. import EVENT_LOGGER_NAME
-from ..messages import (
-    ChatMessage,
-    StopMessage,
-    TextMessage,
-)
-from ._base_chat_agent import BaseChatAgent
+from ._assistant_agent import AssistantAgent
 
 event_logger = logging.getLogger(EVENT_LOGGER_NAME)
 
 
-class ToolCallEvent(BaseModel):
-    """A tool call event."""
-
-    tool_calls: List[FunctionCall]
-    """The tool call message."""
-
-    model_config = ConfigDict(arbitrary_types_allowed=True)
-
-
-class ToolCallResultEvent(BaseModel):
-    """A tool call result event."""
-
-    tool_call_results: List[FunctionExecutionResult]
-    """The tool call result message."""
-
-    model_config = ConfigDict(arbitrary_types_allowed=True)
-
-
-class ToolUseAssistantAgent(BaseChatAgent):
-    """An agent that provides assistance with tool use.
+class ToolUseAssistantAgent(AssistantAgent):
+    """[DEPRECATED] An agent that provides assistance with tool use.
 
     It responds with a StopMessage when 'terminate' is detected in the response.
 
@@ -68,72 +35,12 @@ class ToolUseAssistantAgent(BaseChatAgent):
         description: str = "An agent that provides assistance with ability to use tools.",
         system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed.",
     ):
-        super().__init__(name=name, description=description)
-        self._model_client = model_client
-        self._system_messages = [SystemMessage(content=system_message)]
-        self._tools: List[Tool] = []
-        for tool in registered_tools:
-            if isinstance(tool, Tool):
-                self._tools.append(tool)
-            elif callable(tool):
-                if hasattr(tool, "__doc__") and tool.__doc__ is not None:
-                    description = tool.__doc__
-                else:
-                    description = ""
-                self._tools.append(FunctionTool(tool, description=description))
-            else:
-                raise ValueError(f"Unsupported tool type: {type(tool)}")
-        self._model_context: List[LLMMessage] = []
-
-    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
-        # Add messages to the model context.
-        for msg in messages:
-            # TODO: add special handling for handoff messages
-            self._model_context.append(UserMessage(content=msg.content, source=msg.source))
-
-        # Generate an inference result based on the current model context.
-        llm_messages = self._system_messages + self._model_context
-        result = await self._model_client.create(llm_messages, tools=self._tools, cancellation_token=cancellation_token)
-
-        # Add the response to the model context.
-        self._model_context.append(AssistantMessage(content=result.content, source=self.name))
-
-        # Run tool calls until the model produces a string response.
-        while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
-            event_logger.debug(ToolCallEvent(tool_calls=result.content))
-            # Execute the tool calls.
-            results = await asyncio.gather(
-                *[self._execute_tool_call(call, cancellation_token) for call in result.content]
-            )
-            event_logger.debug(ToolCallResultEvent(tool_call_results=results))
-            self._model_context.append(FunctionExecutionResultMessage(content=results))
-            # Generate an inference result based on the current model context.
-            result = await self._model_client.create(
-                self._model_context, tools=self._tools, cancellation_token=cancellation_token
-            )
-            self._model_context.append(AssistantMessage(content=result.content, source=self.name))
-
-        assert isinstance(result.content, str)
-        # Detect stop request.
-        request_stop = "terminate" in result.content.strip().lower()
-        if request_stop:
-            return StopMessage(content=result.content, source=self.name)
-
-        return TextMessage(content=result.content, source=self.name)
-
-    async def _execute_tool_call(
-        self, tool_call: FunctionCall, cancellation_token: CancellationToken
-    ) -> FunctionExecutionResult:
-        """Execute a tool call and return the result."""
-        try:
-            if not self._tools:
-                raise ValueError("No tools are available.")
-            tool = next((t for t in self._tools if t.name == tool_call.name), None)
-            if tool is None:
-                raise ValueError(f"The tool '{tool_call.name}' is not available.")
-            arguments = json.loads(tool_call.arguments)
-            result = await tool.run_json(arguments, cancellation_token)
-            result_as_str = tool.return_value_as_string(result)
-            return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id)
-        except Exception as e:
-            return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id)
+        # Deprecation warning.
+        warnings.warn(
+            "ToolUseAssistantAgent is deprecated. Use AssistantAgent instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(
+            name, model_client, tools=registered_tools, description=description, system_message=system_message
+        )
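For callers, the visible API change is the keyword rename from registered_tools to tools, as the test updates below also show. A migration sketch (the echo tool is illustrative):

from autogen_agentchat.agents import AssistantAgent, ToolUseAssistantAgent
from autogen_core.components.tools import FunctionTool
from autogen_ext.models import OpenAIChatCompletionClient


async def echo(text: str) -> str:
    """Return the input unchanged (illustrative tool)."""
    return text


client = OpenAIChatCompletionClient(model="gpt-4o-2024-05-13", api_key="")
tool = FunctionTool(echo, description="Echo")

# Before: emits DeprecationWarning, then forwards to AssistantAgent.
old_agent = ToolUseAssistantAgent("tool_user", model_client=client, registered_tools=[tool])

# After: the same tools go through the keyword-only `tools` argument.
new_agent = AssistantAgent("tool_user", model_client=client, tools=[tool])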
@@ -3,7 +3,7 @@ import logging
 import sys
 from datetime import datetime
 
-from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent
+from ..agents._assistant_agent import ToolCallEvent, ToolCallResultEvent
 from ..messages import ChatMessage, StopMessage, TextMessage
 from ..teams._events import (
     GroupChatPublishEvent,
@@ -4,7 +4,7 @@ from dataclasses import asdict, is_dataclass
 from datetime import datetime
 from typing import Any
 
-from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent
+from ..agents._assistant_agent import ToolCallEvent, ToolCallResultEvent
 from ..teams._events import (
     GroupChatPublishEvent,
     GroupChatSelectSpeakerEvent,
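The two handler modules updated above consume ToolCallEvent and ToolCallResultEvent from the shared event logger, so only the import moves. Wiring the logger up in user code looks roughly like this sketch (FileLogHandler taking a file name matches the test imports in this commit, but treat the exact signature as an assumption):

import logging

from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.logging import FileLogHandler

logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.setLevel(logging.DEBUG)  # ToolCallEvent / ToolCallResultEvent are emitted at DEBUG
logger.addHandler(FileLogHandler("agent_events.log"))  # file name is illustrative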
@@ -3,10 +3,10 @@ import json
 from typing import Any, AsyncGenerator, List
 
 import pytest
-from autogen_agentchat.agents import ToolUseAssistantAgent
+from autogen_agentchat.agents import AssistantAgent
 from autogen_agentchat.messages import StopMessage, TextMessage
-from autogen_core.components.models import OpenAIChatCompletionClient
 from autogen_core.components.tools import FunctionTool
+from autogen_ext.models import OpenAIChatCompletionClient
 from openai.resources.chat.completions import AsyncCompletions
 from openai.types.chat.chat_completion import ChatCompletion, Choice
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -42,7 +42,7 @@ async def _echo_function(input: str) -> str:
 
 
 @pytest.mark.asyncio
-async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
+async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     model = "gpt-4o-2024-05-13"
     chat_completions = [
         ChatCompletion(
@@ -97,10 +97,10 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
     ]
     mock = _MockChatCompletion(chat_completions)
     monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
-    tool_use_agent = ToolUseAssistantAgent(
+    tool_use_agent = AssistantAgent(
         "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
-        registered_tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
+        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
     )
     result = await tool_use_agent.run("task")
     assert len(result.messages) == 3
@@ -7,10 +7,9 @@ from typing import Any, AsyncGenerator, List, Sequence
 import pytest
 from autogen_agentchat import EVENT_LOGGER_NAME
 from autogen_agentchat.agents import (
+    AssistantAgent,
     BaseChatAgent,
     CodeExecutorAgent,
-    CodingAssistantAgent,
-    ToolUseAssistantAgent,
 )
 from autogen_agentchat.logging import FileLogHandler
 from autogen_agentchat.messages import (
@@ -131,7 +130,7 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
     code_executor_agent = CodeExecutorAgent(
         "code_executor", code_executor=LocalCommandLineCodeExecutor(work_dir=temp_dir)
     )
-    coding_assistant_agent = CodingAssistantAgent(
+    coding_assistant_agent = AssistantAgent(
         "coding_assistant", model_client=OpenAIChatCompletionClient(model=model, api_key="")
     )
     team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent])
@@ -211,10 +210,10 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
     mock = _MockChatCompletion(chat_completions)
     monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
     tool = FunctionTool(_pass_function, name="pass", description="pass function")
-    tool_use_agent = ToolUseAssistantAgent(
+    tool_use_agent = AssistantAgent(
         "tool_use_agent",
         model_client=OpenAIChatCompletionClient(model=model, api_key=""),
-        registered_tools=[tool],
+        tools=[tool],
     )
     echo_agent = _EchoAgent("echo_agent", description="echo agent")
     team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent])
python/uv.lock (generated, 2553 lines changed): file diff suppressed because it is too large.