2024-10-22 13:27:06 -07:00
|
|
|
import asyncio
|
|
|
|
import json
|
2024-10-29 08:04:14 -07:00
|
|
|
import logging
|
2024-10-22 13:27:06 -07:00
|
|
|
from typing import Any, AsyncGenerator, List
|
|
|
|
|
|
|
|
import pytest
|
2024-10-29 08:04:14 -07:00
|
|
|
from autogen_agentchat import EVENT_LOGGER_NAME
|
|
|
|
from autogen_agentchat.agents import AssistantAgent, Handoff
|
|
|
|
from autogen_agentchat.logging import FileLogHandler
|
|
|
|
from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage
|
|
|
|
from autogen_core.base import CancellationToken
|
2024-10-22 13:27:06 -07:00
|
|
|
from autogen_core.components.tools import FunctionTool
|
2024-10-25 23:17:06 -07:00
|
|
|
from autogen_ext.models import OpenAIChatCompletionClient
|
2024-10-22 13:27:06 -07:00
|
|
|
from openai.resources.chat.completions import AsyncCompletions
|
|
|
|
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
|
|
|
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
|
|
|
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
|
|
|
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
|
|
|
|
from openai.types.completion_usage import CompletionUsage
|
|
|
|
|
2024-10-29 08:04:14 -07:00
|
|
|
# Capture autogen_agentchat event logs (the logger named EVENT_LOGGER_NAME) while
# these tests run. These are module-level statements, so they execute once, when
# pytest imports this test module.
logger = logging.getLogger(EVENT_LOGGER_NAME)
# DEBUG so that debug-level events are not dropped by this logger's level check.
logger.setLevel(logging.DEBUG)
# NOTE(review): FileLogHandler presumably writes records to this relative path
# (i.e. relative to the current working directory) — confirm.
logger.addHandler(FileLogHandler("test_assistant_agent.log"))
|
|
|
|
|
2024-10-22 13:27:06 -07:00
|
|
|
|
|
|
|
class _MockChatCompletion:
    """Stand-in for ``AsyncCompletions.create`` that replays canned completions.

    Each call to :meth:`mock_create` returns the next ``ChatCompletion`` from the
    list supplied at construction, in order, regardless of the request arguments.
    """

    def __init__(self, chat_completions: List[ChatCompletion]) -> None:
        # Cursor pointing at the next completion to hand out.
        self._curr_index = 0
        self._saved_chat_completions = chat_completions

    async def mock_create(
        self, *args: Any, **kwargs: Any
    ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
        """Return the next canned completion; all arguments are ignored.

        Raises:
            IndexError: if called more times than there are canned completions.
        """
        # Short pause before answering; presumably mimics request latency and
        # yields control back to the event loop.
        await asyncio.sleep(0.1)
        position = self._curr_index
        # Look up before advancing, so an exhausted mock leaves the cursor unchanged.
        next_completion = self._saved_chat_completions[position]
        self._curr_index = position + 1
        return next_completion
|
|
|
|
|
|
|
|
|
|
|
|
def _pass_function(input: str) -> str:
|
|
|
|
return "pass"
|
|
|
|
|
|
|
|
|
|
|
|
async def _fail_function(input: str) -> str:
|
|
|
|
return "fail"
|
|
|
|
|
|
|
|
|
|
|
|
async def _echo_function(input: str) -> str:
|
|
|
|
return input
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
    """Run a tool-equipped ``AssistantAgent`` against three canned model completions.

    ``AsyncCompletions.create`` is patched to replay, in order:

    1. a ``tool_calls`` completion invoking ``_pass_function`` with ``{"input": "task"}``;
    2. a plain ``"Hello"`` reply;
    3. a ``"TERMINATE"`` reply.

    ``run("task")`` must yield exactly three messages: two ``TextMessage``s
    followed by a ``StopMessage``.
    """
    model = "gpt-4o-2024-05-13"
    # Each canned completion carries its own distinct id.
    chat_completions = [
        ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="tool_calls",
                    index=0,
                    message=ChatCompletionMessage(
                        content=None,
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                id="1",
                                type="function",
                                function=Function(
                                    name="_pass_function",
                                    arguments=json.dumps({"input": "task"}),
                                ),
                            )
                        ],
                        role="assistant",
                    ),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id3",
            choices=[
                Choice(
                    finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    # Route every OpenAI chat-completion request through the canned replayer.
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    tool_use_agent = AssistantAgent(
        "tool_use_agent",
        # Empty API key: no real request is made because `create` is patched above.
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
    )
    result = await tool_use_agent.run("task")
    # NOTE(review): which canned completion produces which message depends on
    # AssistantAgent.run / on_messages internals not visible in this file.
    assert len(result.messages) == 3
    assert isinstance(result.messages[0], TextMessage)
    assert isinstance(result.messages[1], TextMessage)
    assert isinstance(result.messages[2], StopMessage)
|
2024-10-29 08:04:14 -07:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that a model tool call naming a handoff yields a ``HandoffMessage``.

    A single canned completion is queued: a ``tool_calls`` completion whose
    function name is ``handoff.name`` with empty JSON arguments. The agent's
    reply to one user ``TextMessage`` must be a ``HandoffMessage`` whose target
    is ``"agent2"``.
    """
    handoff = Handoff(target="agent2")
    model = "gpt-4o-2024-05-13"

    # Assemble the canned completion bottom-up: tool call -> message -> choice.
    handoff_call = ChatCompletionMessageToolCall(
        id="1",
        type="function",
        function=Function(name=handoff.name, arguments=json.dumps({})),
    )
    assistant_message = ChatCompletionMessage(
        content=None,
        tool_calls=[handoff_call],
        role="assistant",
    )
    handoff_choice = Choice(finish_reason="tool_calls", index=0, message=assistant_message)
    zero_usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
    chat_completions = [
        ChatCompletion(
            id="id1",
            choices=[handoff_choice],
            created=0,
            model=model,
            object="chat.completion",
            usage=zero_usage,
        ),
    ]

    # Only one completion is queued, so a second model call would exhaust the
    # replayer (IndexError) — the handoff must be resolved from the first reply.
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

    agent = AssistantAgent(
        "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
        handoffs=[handoff],
    )
    user_task = TextMessage(content="task", source="user")
    response = await agent.on_messages([user_task], cancellation_token=CancellationToken())

    assert isinstance(response, HandoffMessage)
    assert response.target == "agent2"
|