mirror of
https://github.com/microsoft/autogen.git
synced 2025-10-03 12:08:08 +00:00
Refactor agentchat +implement base chat agent run method (#3913)
This commit is contained in:
parent
8f6dc4e1dd
commit
1812cc068d
@ -1,8 +1,11 @@
|
|||||||
|
from ._base_chat_agent import BaseChatAgent, BaseToolUseChatAgent
|
||||||
from ._code_executor_agent import CodeExecutorAgent
|
from ._code_executor_agent import CodeExecutorAgent
|
||||||
from ._coding_assistant_agent import CodingAssistantAgent
|
from ._coding_assistant_agent import CodingAssistantAgent
|
||||||
from ._tool_use_assistant_agent import ToolUseAssistantAgent
|
from ._tool_use_assistant_agent import ToolUseAssistantAgent
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"BaseChatAgent",
|
||||||
|
"BaseToolUseChatAgent",
|
||||||
"CodeExecutorAgent",
|
"CodeExecutorAgent",
|
||||||
"CodingAssistantAgent",
|
"CodingAssistantAgent",
|
||||||
"ToolUseAssistantAgent",
|
"ToolUseAssistantAgent",
|
||||||
|
@ -4,12 +4,13 @@ from typing import List, Sequence
|
|||||||
from autogen_core.base import CancellationToken
|
from autogen_core.base import CancellationToken
|
||||||
from autogen_core.components.tools import Tool
|
from autogen_core.components.tools import Tool
|
||||||
|
|
||||||
|
from ..base import ChatAgent, TaskResult, TerminationCondition, ToolUseChatAgent
|
||||||
from ..messages import ChatMessage
|
from ..messages import ChatMessage
|
||||||
from ._base_task import TaskResult, TaskRunner
|
from ..teams import RoundRobinGroupChat
|
||||||
|
|
||||||
|
|
||||||
class BaseChatAgent(TaskRunner, ABC):
|
class BaseChatAgent(ChatAgent, ABC):
|
||||||
"""Base class for a chat agent that can participant in a team."""
|
"""Base class for a chat agent."""
|
||||||
|
|
||||||
def __init__(self, name: str, description: str) -> None:
|
def __init__(self, name: str, description: str) -> None:
|
||||||
self._name = name
|
self._name = name
|
||||||
@ -36,13 +37,23 @@ class BaseChatAgent(TaskRunner, ABC):
|
|||||||
...
|
...
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self, task: str, *, source: str = "user", cancellation_token: CancellationToken | None = None
|
self,
|
||||||
|
task: str,
|
||||||
|
*,
|
||||||
|
cancellation_token: CancellationToken | None = None,
|
||||||
|
termination_condition: TerminationCondition | None = None,
|
||||||
) -> TaskResult:
|
) -> TaskResult:
|
||||||
# TODO: Implement this method.
|
"""Run the agent with the given task and return the result."""
|
||||||
raise NotImplementedError
|
group_chat = RoundRobinGroupChat(participants=[self])
|
||||||
|
result = await group_chat.run(
|
||||||
|
task=task,
|
||||||
|
cancellation_token=cancellation_token,
|
||||||
|
termination_condition=termination_condition,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class BaseToolUseChatAgent(BaseChatAgent):
|
class BaseToolUseChatAgent(BaseChatAgent, ToolUseChatAgent):
|
||||||
"""Base class for a chat agent that can use tools.
|
"""Base class for a chat agent that can use tools.
|
||||||
|
|
||||||
Subclass this base class to create an agent class that uses tools by returning
|
Subclass this base class to create an agent class that uses tools by returning
|
@ -3,8 +3,8 @@ from typing import List, Sequence
|
|||||||
from autogen_core.base import CancellationToken
|
from autogen_core.base import CancellationToken
|
||||||
from autogen_core.components.code_executor import CodeBlock, CodeExecutor, extract_markdown_code_blocks
|
from autogen_core.components.code_executor import CodeBlock, CodeExecutor, extract_markdown_code_blocks
|
||||||
|
|
||||||
from ..base import BaseChatAgent
|
|
||||||
from ..messages import ChatMessage, TextMessage
|
from ..messages import ChatMessage, TextMessage
|
||||||
|
from ._base_chat_agent import BaseChatAgent
|
||||||
|
|
||||||
|
|
||||||
class CodeExecutorAgent(BaseChatAgent):
|
class CodeExecutorAgent(BaseChatAgent):
|
||||||
|
@ -9,8 +9,8 @@ from autogen_core.components.models import (
|
|||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..base import BaseChatAgent
|
|
||||||
from ..messages import ChatMessage, MultiModalMessage, StopMessage, TextMessage
|
from ..messages import ChatMessage, MultiModalMessage, StopMessage, TextMessage
|
||||||
|
from ._base_chat_agent import BaseChatAgent
|
||||||
|
|
||||||
|
|
||||||
class CodingAssistantAgent(BaseChatAgent):
|
class CodingAssistantAgent(BaseChatAgent):
|
||||||
|
@ -12,7 +12,6 @@ from autogen_core.components.models import (
|
|||||||
)
|
)
|
||||||
from autogen_core.components.tools import FunctionTool, Tool
|
from autogen_core.components.tools import FunctionTool, Tool
|
||||||
|
|
||||||
from ..base import BaseToolUseChatAgent
|
|
||||||
from ..messages import (
|
from ..messages import (
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
MultiModalMessage,
|
MultiModalMessage,
|
||||||
@ -21,6 +20,7 @@ from ..messages import (
|
|||||||
ToolCallMessage,
|
ToolCallMessage,
|
||||||
ToolCallResultMessage,
|
ToolCallResultMessage,
|
||||||
)
|
)
|
||||||
|
from ._base_chat_agent import BaseToolUseChatAgent
|
||||||
|
|
||||||
|
|
||||||
class ToolUseAssistantAgent(BaseToolUseChatAgent):
|
class ToolUseAssistantAgent(BaseToolUseChatAgent):
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
from ._base_chat_agent import BaseChatAgent, BaseToolUseChatAgent
|
from ._chat_agent import ChatAgent, ToolUseChatAgent
|
||||||
from ._base_task import TaskResult, TaskRunner
|
from ._task import TaskResult, TaskRunner
|
||||||
from ._base_team import Team
|
from ._team import Team
|
||||||
from ._base_termination import TerminatedException, TerminationCondition
|
from ._termination import TerminatedException, TerminationCondition
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseChatAgent",
|
"ChatAgent",
|
||||||
"BaseToolUseChatAgent",
|
"ToolUseChatAgent",
|
||||||
"Team",
|
"Team",
|
||||||
"TerminatedException",
|
"TerminatedException",
|
||||||
"TerminationCondition",
|
"TerminationCondition",
|
||||||
|
@ -1,10 +0,0 @@
|
|||||||
from typing import Protocol
|
|
||||||
|
|
||||||
from ._base_task import TaskResult, TaskRunner
|
|
||||||
from ._base_termination import TerminationCondition
|
|
||||||
|
|
||||||
|
|
||||||
class Team(TaskRunner, Protocol):
|
|
||||||
async def run(self, task: str, *, termination_condition: TerminationCondition | None = None) -> TaskResult:
|
|
||||||
"""Run the team on a given task until the termination condition is met."""
|
|
||||||
...
|
|
@ -0,0 +1,50 @@
|
|||||||
|
from typing import List, Protocol, Sequence, runtime_checkable
|
||||||
|
|
||||||
|
from autogen_core.base import CancellationToken
|
||||||
|
from autogen_core.components.tools import Tool
|
||||||
|
|
||||||
|
from ..messages import ChatMessage
|
||||||
|
from ._task import TaskResult, TaskRunner
|
||||||
|
from ._termination import TerminationCondition
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class ChatAgent(TaskRunner, Protocol):
|
||||||
|
"""Protocol for a chat agent."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""The name of the agent. This is used by team to uniquely identify
|
||||||
|
the agent. It should be unique within the team."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
"""The description of the agent. This is used by team to
|
||||||
|
make decisions about which agents to use. The description should
|
||||||
|
describe the agent's capabilities and how to interact with it."""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||||
|
"""Handle incoming messages and return a response message."""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
*,
|
||||||
|
cancellation_token: CancellationToken | None = None,
|
||||||
|
termination_condition: TerminationCondition | None = None,
|
||||||
|
) -> TaskResult:
|
||||||
|
"""Run the agent with the given task and return the result."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class ToolUseChatAgent(ChatAgent, Protocol):
|
||||||
|
"""Protocol for a chat agent that can use tools."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def registered_tools(self) -> List[Tool]:
|
||||||
|
"""The list of tools that the agent can use."""
|
||||||
|
...
|
@ -1,7 +1,10 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Protocol, Sequence
|
from typing import Protocol, Sequence
|
||||||
|
|
||||||
|
from autogen_core.base import CancellationToken
|
||||||
|
|
||||||
from ..messages import ChatMessage
|
from ..messages import ChatMessage
|
||||||
|
from ._termination import TerminationCondition
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -15,6 +18,12 @@ class TaskResult:
|
|||||||
class TaskRunner(Protocol):
|
class TaskRunner(Protocol):
|
||||||
"""A task runner."""
|
"""A task runner."""
|
||||||
|
|
||||||
async def run(self, task: str) -> TaskResult:
|
async def run(
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
*,
|
||||||
|
cancellation_token: CancellationToken | None = None,
|
||||||
|
termination_condition: TerminationCondition | None = None,
|
||||||
|
) -> TaskResult:
|
||||||
"""Run the task."""
|
"""Run the task."""
|
||||||
...
|
...
|
@ -0,0 +1,18 @@
|
|||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from autogen_core.base import CancellationToken
|
||||||
|
|
||||||
|
from ._task import TaskResult, TaskRunner
|
||||||
|
from ._termination import TerminationCondition
|
||||||
|
|
||||||
|
|
||||||
|
class Team(TaskRunner, Protocol):
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
*,
|
||||||
|
cancellation_token: CancellationToken | None = None,
|
||||||
|
termination_condition: TerminationCondition | None = None,
|
||||||
|
) -> TaskResult:
|
||||||
|
"""Run the team on a given task until the termination condition is met."""
|
||||||
|
...
|
@ -0,0 +1,7 @@
|
|||||||
|
from ._terminations import MaxMessageTermination, StopMessageTermination, TextMentionTermination
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MaxMessageTermination",
|
||||||
|
"TextMentionTermination",
|
||||||
|
"StopMessageTermination",
|
||||||
|
]
|
@ -1,11 +1,7 @@
|
|||||||
from ._group_chat._round_robin_group_chat import RoundRobinGroupChat
|
from ._group_chat._round_robin_group_chat import RoundRobinGroupChat
|
||||||
from ._group_chat._selector_group_chat import SelectorGroupChat
|
from ._group_chat._selector_group_chat import SelectorGroupChat
|
||||||
from ._terminations import MaxMessageTermination, StopMessageTermination, TextMentionTermination
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MaxMessageTermination",
|
|
||||||
"TextMentionTermination",
|
|
||||||
"StopMessageTermination",
|
|
||||||
"RoundRobinGroupChat",
|
"RoundRobinGroupChat",
|
||||||
"SelectorGroupChat",
|
"SelectorGroupChat",
|
||||||
]
|
]
|
||||||
|
@ -8,7 +8,7 @@ from autogen_core.components.models import FunctionExecutionResult
|
|||||||
from autogen_core.components.tool_agent import ToolException
|
from autogen_core.components.tool_agent import ToolException
|
||||||
|
|
||||||
from ... import EVENT_LOGGER_NAME
|
from ... import EVENT_LOGGER_NAME
|
||||||
from ...base import BaseChatAgent
|
from ...base import ChatAgent
|
||||||
from ...messages import MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
|
from ...messages import MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
|
||||||
from .._events import ContentPublishEvent, ContentRequestEvent, ToolCallEvent, ToolCallResultEvent
|
from .._events import ContentPublishEvent, ContentRequestEvent, ToolCallEvent, ToolCallResultEvent
|
||||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||||
@ -27,7 +27,7 @@ class BaseChatAgentContainer(SequentialRoutedAgent):
|
|||||||
tool_agent_type (AgentType, optional): The agent type of the tool agent. Defaults to None.
|
tool_agent_type (AgentType, optional): The agent type of the tool agent. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType | None = None) -> None:
|
def __init__(self, parent_topic_type: str, agent: ChatAgent, tool_agent_type: AgentType | None = None) -> None:
|
||||||
super().__init__(description=agent.description)
|
super().__init__(description=agent.description)
|
||||||
self._parent_topic_type = parent_topic_type
|
self._parent_topic_type = parent_topic_type
|
||||||
self._agent = agent
|
self._agent = agent
|
||||||
|
@ -3,12 +3,20 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Callable, List
|
from typing import Callable, List
|
||||||
|
|
||||||
from autogen_core.application import SingleThreadedAgentRuntime
|
from autogen_core.application import SingleThreadedAgentRuntime
|
||||||
from autogen_core.base import AgentId, AgentInstantiationContext, AgentRuntime, AgentType, MessageContext, TopicId
|
from autogen_core.base import (
|
||||||
|
AgentId,
|
||||||
|
AgentInstantiationContext,
|
||||||
|
AgentRuntime,
|
||||||
|
AgentType,
|
||||||
|
CancellationToken,
|
||||||
|
MessageContext,
|
||||||
|
TopicId,
|
||||||
|
)
|
||||||
from autogen_core.components import ClosureAgent, TypeSubscription
|
from autogen_core.components import ClosureAgent, TypeSubscription
|
||||||
from autogen_core.components.tool_agent import ToolAgent
|
from autogen_core.components.tool_agent import ToolAgent
|
||||||
from autogen_core.components.tools import Tool
|
from autogen_core.components.tools import Tool
|
||||||
|
|
||||||
from ...base import BaseChatAgent, BaseToolUseChatAgent, TaskResult, Team, TerminationCondition
|
from ...base import ChatAgent, TaskResult, Team, TerminationCondition, ToolUseChatAgent
|
||||||
from ...messages import ChatMessage, TextMessage
|
from ...messages import ChatMessage, TextMessage
|
||||||
from .._events import ContentPublishEvent, ContentRequestEvent
|
from .._events import ContentPublishEvent, ContentRequestEvent
|
||||||
from ._base_chat_agent_container import BaseChatAgentContainer
|
from ._base_chat_agent_container import BaseChatAgentContainer
|
||||||
@ -22,13 +30,13 @@ class BaseGroupChat(Team, ABC):
|
|||||||
create a subclass of :class:`BaseGroupChat` that uses the group chat manager.
|
create a subclass of :class:`BaseGroupChat` that uses the group chat manager.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, participants: List[BaseChatAgent], group_chat_manager_class: type[BaseGroupChatManager]):
|
def __init__(self, participants: List[ChatAgent], group_chat_manager_class: type[BaseGroupChatManager]):
|
||||||
if len(participants) == 0:
|
if len(participants) == 0:
|
||||||
raise ValueError("At least one participant is required.")
|
raise ValueError("At least one participant is required.")
|
||||||
if len(participants) != len(set(participant.name for participant in participants)):
|
if len(participants) != len(set(participant.name for participant in participants)):
|
||||||
raise ValueError("The participant names must be unique.")
|
raise ValueError("The participant names must be unique.")
|
||||||
for participant in participants:
|
for participant in participants:
|
||||||
if isinstance(participant, BaseToolUseChatAgent) and not participant.registered_tools:
|
if isinstance(participant, ToolUseChatAgent) and not participant.registered_tools:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Participant '{participant.name}' is a tool use agent so it must have registered tools."
|
f"Participant '{participant.name}' is a tool use agent so it must have registered tools."
|
||||||
)
|
)
|
||||||
@ -47,7 +55,7 @@ class BaseGroupChat(Team, ABC):
|
|||||||
) -> Callable[[], BaseGroupChatManager]: ...
|
) -> Callable[[], BaseGroupChatManager]: ...
|
||||||
|
|
||||||
def _create_participant_factory(
|
def _create_participant_factory(
|
||||||
self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType | None
|
self, parent_topic_type: str, agent: ChatAgent, tool_agent_type: AgentType | None
|
||||||
) -> Callable[[], BaseChatAgentContainer]:
|
) -> Callable[[], BaseChatAgentContainer]:
|
||||||
def _factory() -> BaseChatAgentContainer:
|
def _factory() -> BaseChatAgentContainer:
|
||||||
id = AgentInstantiationContext.current_agent_id()
|
id = AgentInstantiationContext.current_agent_id()
|
||||||
@ -68,7 +76,13 @@ class BaseGroupChat(Team, ABC):
|
|||||||
|
|
||||||
return _factory
|
return _factory
|
||||||
|
|
||||||
async def run(self, task: str, *, termination_condition: TerminationCondition | None = None) -> TaskResult:
|
async def run(
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
*,
|
||||||
|
cancellation_token: CancellationToken | None = None,
|
||||||
|
termination_condition: TerminationCondition | None = None,
|
||||||
|
) -> TaskResult:
|
||||||
"""Run the team and return the result."""
|
"""Run the team and return the result."""
|
||||||
# Create intervention handler for termination.
|
# Create intervention handler for termination.
|
||||||
|
|
||||||
@ -85,7 +99,7 @@ class BaseGroupChat(Team, ABC):
|
|||||||
participant_topic_types: List[str] = []
|
participant_topic_types: List[str] = []
|
||||||
participant_descriptions: List[str] = []
|
participant_descriptions: List[str] = []
|
||||||
for participant in self._participants:
|
for participant in self._participants:
|
||||||
if isinstance(participant, BaseToolUseChatAgent):
|
if isinstance(participant, ToolUseChatAgent):
|
||||||
assert participant.registered_tools is not None and len(participant.registered_tools) > 0
|
assert participant.registered_tools is not None and len(participant.registered_tools) > 0
|
||||||
# Register the tool agent.
|
# Register the tool agent.
|
||||||
tool_agent_type = await ToolAgent.register(
|
tool_agent_type = await ToolAgent.register(
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from typing import Callable, List
|
from typing import Callable, List
|
||||||
|
|
||||||
from ...base import BaseChatAgent, TerminationCondition
|
from ...base import ChatAgent, TerminationCondition
|
||||||
from .._events import ContentPublishEvent
|
from .._events import ContentPublishEvent
|
||||||
from ._base_group_chat import BaseGroupChat
|
from ._base_group_chat import BaseGroupChat
|
||||||
from ._base_group_chat_manager import BaseGroupChatManager
|
from ._base_group_chat_manager import BaseGroupChatManager
|
||||||
@ -73,7 +73,7 @@ class RoundRobinGroupChat(BaseGroupChat):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, participants: List[BaseChatAgent]):
|
def __init__(self, participants: List[ChatAgent]):
|
||||||
super().__init__(participants, group_chat_manager_class=RoundRobinGroupChatManager)
|
super().__init__(participants, group_chat_manager_class=RoundRobinGroupChatManager)
|
||||||
|
|
||||||
def _create_group_chat_manager_factory(
|
def _create_group_chat_manager_factory(
|
||||||
|
@ -5,7 +5,7 @@ from typing import Callable, Dict, List
|
|||||||
from autogen_core.components.models import ChatCompletionClient, SystemMessage
|
from autogen_core.components.models import ChatCompletionClient, SystemMessage
|
||||||
|
|
||||||
from ... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
|
from ... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
|
||||||
from ...base import BaseChatAgent, TerminationCondition
|
from ...base import ChatAgent, TerminationCondition
|
||||||
from ...messages import MultiModalMessage, StopMessage, TextMessage
|
from ...messages import MultiModalMessage, StopMessage, TextMessage
|
||||||
from .._events import ContentPublishEvent, SelectSpeakerEvent
|
from .._events import ContentPublishEvent, SelectSpeakerEvent
|
||||||
from ._base_group_chat import BaseGroupChat
|
from ._base_group_chat import BaseGroupChat
|
||||||
@ -178,7 +178,7 @@ class SelectorGroupChat(BaseGroupChat):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
participants: List[BaseChatAgent],
|
participants: List[ChatAgent],
|
||||||
model_client: ChatCompletionClient,
|
model_client: ChatCompletionClient,
|
||||||
*,
|
*,
|
||||||
selector_prompt: str = """You are in a role play game. The following roles are available:
|
selector_prompt: str = """You are in a role play game. The following roles are available:
|
||||||
|
@ -7,17 +7,17 @@ from typing import Any, AsyncGenerator, List, Sequence
|
|||||||
import pytest
|
import pytest
|
||||||
from autogen_agentchat import EVENT_LOGGER_NAME
|
from autogen_agentchat import EVENT_LOGGER_NAME
|
||||||
from autogen_agentchat.agents import (
|
from autogen_agentchat.agents import (
|
||||||
|
BaseChatAgent,
|
||||||
CodeExecutorAgent,
|
CodeExecutorAgent,
|
||||||
CodingAssistantAgent,
|
CodingAssistantAgent,
|
||||||
ToolUseAssistantAgent,
|
ToolUseAssistantAgent,
|
||||||
)
|
)
|
||||||
from autogen_agentchat.base import BaseChatAgent
|
|
||||||
from autogen_agentchat.logging import FileLogHandler
|
from autogen_agentchat.logging import FileLogHandler
|
||||||
from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage
|
from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage
|
||||||
|
from autogen_agentchat.task import StopMessageTermination
|
||||||
from autogen_agentchat.teams import (
|
from autogen_agentchat.teams import (
|
||||||
RoundRobinGroupChat,
|
RoundRobinGroupChat,
|
||||||
SelectorGroupChat,
|
SelectorGroupChat,
|
||||||
StopMessageTermination,
|
|
||||||
)
|
)
|
||||||
from autogen_core.base import CancellationToken
|
from autogen_core.base import CancellationToken
|
||||||
from autogen_core.components import FunctionCall
|
from autogen_core.components import FunctionCall
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from autogen_agentchat.messages import StopMessage, TextMessage
|
from autogen_agentchat.messages import StopMessage, TextMessage
|
||||||
from autogen_agentchat.teams import MaxMessageTermination, StopMessageTermination, TextMentionTermination
|
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination, TextMentionTermination
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -4,13 +4,7 @@ from typing import Any, AsyncGenerator, List
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from autogen_agentchat.agents import ToolUseAssistantAgent
|
from autogen_agentchat.agents import ToolUseAssistantAgent
|
||||||
from autogen_agentchat.messages import (
|
from autogen_core.components.models import OpenAIChatCompletionClient
|
||||||
TextMessage,
|
|
||||||
ToolCallMessage,
|
|
||||||
ToolCallResultMessage,
|
|
||||||
)
|
|
||||||
from autogen_core.base import CancellationToken
|
|
||||||
from autogen_core.components.models import FunctionExecutionResult, OpenAIChatCompletionClient
|
|
||||||
from autogen_core.components.tools import FunctionTool
|
from autogen_core.components.tools import FunctionTool
|
||||||
from openai.resources.chat.completions import AsyncCompletions
|
from openai.resources.chat.completions import AsyncCompletions
|
||||||
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||||
@ -63,8 +57,8 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
|
|||||||
id="1",
|
id="1",
|
||||||
type="function",
|
type="function",
|
||||||
function=Function(
|
function=Function(
|
||||||
name="pass",
|
name="_pass_function",
|
||||||
arguments=json.dumps({"input": "pass"}),
|
arguments=json.dumps({"input": "task"}),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
@ -107,14 +101,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
|
|||||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||||
registered_tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
registered_tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
||||||
)
|
)
|
||||||
response = await tool_use_agent.on_messages(
|
result = await tool_use_agent.run("task")
|
||||||
messages=[TextMessage(content="Test", source="user")], cancellation_token=CancellationToken()
|
assert len(result.messages) == 3
|
||||||
)
|
# assert isinstance(result.messages[1], ToolCallMessage)
|
||||||
assert isinstance(response, ToolCallMessage)
|
# assert isinstance(result.messages[2], TextMessage)
|
||||||
tool_call_results = [FunctionExecutionResult(content="", call_id=call.id) for call in response.content]
|
|
||||||
|
|
||||||
response = await tool_use_agent.on_messages(
|
|
||||||
messages=[ToolCallResultMessage(content=tool_call_results, source="test")],
|
|
||||||
cancellation_token=CancellationToken(),
|
|
||||||
)
|
|
||||||
assert isinstance(response, TextMessage)
|
|
||||||
|
@ -23,7 +23,8 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from autogen_agentchat.agents import CodingAssistantAgent, ToolUseAssistantAgent\n",
|
"from autogen_agentchat.agents import CodingAssistantAgent, ToolUseAssistantAgent\n",
|
||||||
"from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination\n",
|
"from autogen_agentchat.task import StopMessageTermination\n",
|
||||||
|
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
|
||||||
"from autogen_core.components.tools import FunctionTool\n",
|
"from autogen_core.components.tools import FunctionTool\n",
|
||||||
"from autogen_ext.models import OpenAIChatCompletionClient"
|
"from autogen_ext.models import OpenAIChatCompletionClient"
|
||||||
]
|
]
|
||||||
|
@ -23,7 +23,8 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from autogen_agentchat.agents import CodingAssistantAgent, ToolUseAssistantAgent\n",
|
"from autogen_agentchat.agents import CodingAssistantAgent, ToolUseAssistantAgent\n",
|
||||||
"from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination\n",
|
"from autogen_agentchat.task import StopMessageTermination\n",
|
||||||
|
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
|
||||||
"from autogen_core.components.tools import FunctionTool\n",
|
"from autogen_core.components.tools import FunctionTool\n",
|
||||||
"from autogen_ext.models import OpenAIChatCompletionClient"
|
"from autogen_ext.models import OpenAIChatCompletionClient"
|
||||||
]
|
]
|
||||||
|
@ -18,7 +18,8 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from autogen_agentchat.agents import CodingAssistantAgent\n",
|
"from autogen_agentchat.agents import CodingAssistantAgent\n",
|
||||||
"from autogen_agentchat.teams import RoundRobinGroupChat, StopMessageTermination\n",
|
"from autogen_agentchat.task import StopMessageTermination\n",
|
||||||
|
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
|
||||||
"from autogen_ext.models import OpenAIChatCompletionClient"
|
"from autogen_ext.models import OpenAIChatCompletionClient"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -58,7 +58,8 @@
|
|||||||
"from autogen_agentchat import EVENT_LOGGER_NAME\n",
|
"from autogen_agentchat import EVENT_LOGGER_NAME\n",
|
||||||
"from autogen_agentchat.agents import ToolUseAssistantAgent\n",
|
"from autogen_agentchat.agents import ToolUseAssistantAgent\n",
|
||||||
"from autogen_agentchat.logging import ConsoleLogHandler\n",
|
"from autogen_agentchat.logging import ConsoleLogHandler\n",
|
||||||
"from autogen_agentchat.teams import MaxMessageTermination, RoundRobinGroupChat\n",
|
"from autogen_agentchat.task import MaxMessageTermination\n",
|
||||||
|
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
|
||||||
"from autogen_core.components.tools import FunctionTool\n",
|
"from autogen_core.components.tools import FunctionTool\n",
|
||||||
"from autogen_ext.models import OpenAIChatCompletionClient\n",
|
"from autogen_ext.models import OpenAIChatCompletionClient\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -126,7 +127,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.12.5"
|
"version": "3.12.6"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -45,10 +45,9 @@
|
|||||||
"import logging\n",
|
"import logging\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from autogen_agentchat import EVENT_LOGGER_NAME\n",
|
"from autogen_agentchat import EVENT_LOGGER_NAME\n",
|
||||||
"from autogen_agentchat.agents import CodingAssistantAgent, ToolUseAssistantAgent\n",
|
"from autogen_agentchat.agents import ToolUseAssistantAgent\n",
|
||||||
"from autogen_agentchat.logging import ConsoleLogHandler\n",
|
"from autogen_agentchat.logging import ConsoleLogHandler\n",
|
||||||
"from autogen_agentchat.messages import TextMessage\n",
|
"from autogen_agentchat.messages import TextMessage\n",
|
||||||
"from autogen_agentchat.teams import MaxMessageTermination, RoundRobinGroupChat, SelectorGroupChat\n",
|
|
||||||
"from autogen_core.base import CancellationToken\n",
|
"from autogen_core.base import CancellationToken\n",
|
||||||
"from autogen_core.components.models import OpenAIChatCompletionClient\n",
|
"from autogen_core.components.models import OpenAIChatCompletionClient\n",
|
||||||
"from autogen_core.components.tools import FunctionTool\n",
|
"from autogen_core.components.tools import FunctionTool\n",
|
||||||
@ -251,7 +250,7 @@
|
|||||||
"import asyncio\n",
|
"import asyncio\n",
|
||||||
"from typing import Sequence\n",
|
"from typing import Sequence\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from autogen_agentchat.base import BaseChatAgent\n",
|
"from autogen_agentchat.agents import BaseChatAgent\n",
|
||||||
"from autogen_agentchat.messages import (\n",
|
"from autogen_agentchat.messages import (\n",
|
||||||
" ChatMessage,\n",
|
" ChatMessage,\n",
|
||||||
" StopMessage,\n",
|
" StopMessage,\n",
|
||||||
@ -313,7 +312,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.5"
|
"version": "3.12.6"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -98,7 +98,7 @@
|
|||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "agnext",
|
"display_name": ".venv",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
@ -112,7 +112,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.9"
|
"version": "3.12.6"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -41,12 +41,13 @@
|
|||||||
"from typing import Sequence\n",
|
"from typing import Sequence\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from autogen_agentchat.agents import (\n",
|
"from autogen_agentchat.agents import (\n",
|
||||||
|
" BaseChatAgent,\n",
|
||||||
" CodingAssistantAgent,\n",
|
" CodingAssistantAgent,\n",
|
||||||
" ToolUseAssistantAgent,\n",
|
" ToolUseAssistantAgent,\n",
|
||||||
")\n",
|
")\n",
|
||||||
"from autogen_agentchat.base import BaseChatAgent\n",
|
|
||||||
"from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage\n",
|
"from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage\n",
|
||||||
"from autogen_agentchat.teams import SelectorGroupChat, StopMessageTermination\n",
|
"from autogen_agentchat.task import StopMessageTermination\n",
|
||||||
|
"from autogen_agentchat.teams import SelectorGroupChat\n",
|
||||||
"from autogen_core.base import CancellationToken\n",
|
"from autogen_core.base import CancellationToken\n",
|
||||||
"from autogen_core.components.tools import FunctionTool\n",
|
"from autogen_core.components.tools import FunctionTool\n",
|
||||||
"from autogen_ext.models import OpenAIChatCompletionClient"
|
"from autogen_ext.models import OpenAIChatCompletionClient"
|
||||||
@ -268,7 +269,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.5"
|
"version": "3.12.6"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -28,7 +28,8 @@
|
|||||||
"from autogen_agentchat import EVENT_LOGGER_NAME\n",
|
"from autogen_agentchat import EVENT_LOGGER_NAME\n",
|
||||||
"from autogen_agentchat.agents import CodingAssistantAgent, ToolUseAssistantAgent\n",
|
"from autogen_agentchat.agents import CodingAssistantAgent, ToolUseAssistantAgent\n",
|
||||||
"from autogen_agentchat.logging import ConsoleLogHandler\n",
|
"from autogen_agentchat.logging import ConsoleLogHandler\n",
|
||||||
"from autogen_agentchat.teams import MaxMessageTermination, RoundRobinGroupChat, SelectorGroupChat\n",
|
"from autogen_agentchat.task import MaxMessageTermination\n",
|
||||||
|
"from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat\n",
|
||||||
"from autogen_core.components.models import OpenAIChatCompletionClient\n",
|
"from autogen_core.components.models import OpenAIChatCompletionClient\n",
|
||||||
"from autogen_core.components.tools import FunctionTool\n",
|
"from autogen_core.components.tools import FunctionTool\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -208,7 +209,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.5"
|
"version": "3.12.6"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -37,7 +37,8 @@
|
|||||||
"from autogen_agentchat import EVENT_LOGGER_NAME\n",
|
"from autogen_agentchat import EVENT_LOGGER_NAME\n",
|
||||||
"from autogen_agentchat.agents import CodingAssistantAgent\n",
|
"from autogen_agentchat.agents import CodingAssistantAgent\n",
|
||||||
"from autogen_agentchat.logging import ConsoleLogHandler\n",
|
"from autogen_agentchat.logging import ConsoleLogHandler\n",
|
||||||
"from autogen_agentchat.teams import MaxMessageTermination, RoundRobinGroupChat, StopMessageTermination\n",
|
"from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination\n",
|
||||||
|
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
|
||||||
"from autogen_core.components.models import OpenAIChatCompletionClient\n",
|
"from autogen_core.components.models import OpenAIChatCompletionClient\n",
|
||||||
"\n",
|
"\n",
|
||||||
"logger = logging.getLogger(EVENT_LOGGER_NAME)\n",
|
"logger = logging.getLogger(EVENT_LOGGER_NAME)\n",
|
||||||
@ -198,7 +199,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.5"
|
"version": "3.12.6"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user