2024-10-22 13:27:06 -07:00
|
|
|
import asyncio
|
|
|
|
import json
|
2024-10-29 08:04:14 -07:00
|
|
|
import logging
|
2024-10-22 13:27:06 -07:00
|
|
|
from typing import Any, AsyncGenerator, List
|
|
|
|
|
|
|
|
import pytest
|
2024-10-29 08:04:14 -07:00
|
|
|
from autogen_agentchat import EVENT_LOGGER_NAME
|
2024-12-03 14:34:55 -08:00
|
|
|
from autogen_agentchat.agents import AssistantAgent
|
|
|
|
from autogen_agentchat.base import Handoff, TaskResult
|
2024-11-07 21:38:41 -08:00
|
|
|
from autogen_agentchat.messages import (
|
2024-12-15 11:18:17 +05:30
|
|
|
ChatMessage,
|
2024-11-07 21:38:41 -08:00
|
|
|
HandoffMessage,
|
Memory Interface in AgentChat (#4438)
* initial base memroy impl
* update, add example with chromadb
* include mimetype consideration
* add transform method
* update to address feedback, will update after 4681 is merged
* update memory impl,
* remove chroma db, typing fixes
* format, add test
* update uv lock
* update docs
* format updates
* update notebook
* add memoryqueryevent message, yield message for observability.
* minor fixes, make score optional/none
* Update python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
* update tests to improve cov
* refactor, move memory to core.
* format fixxes
* format updates
* format updates
* fix azure notebook import, other fixes
* update notebook, support str query in Memory protocol
* update test
* update cells
* add specific extensible return types to memory query and update_context
---------
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
2025-01-13 23:06:13 -08:00
|
|
|
MemoryQueryEvent,
|
2024-11-07 21:38:41 -08:00
|
|
|
MultiModalMessage,
|
|
|
|
TextMessage,
|
2024-12-18 14:09:19 -08:00
|
|
|
ToolCallExecutionEvent,
|
|
|
|
ToolCallRequestEvent,
|
2024-12-20 00:23:18 -05:00
|
|
|
ToolCallSummaryMessage,
|
2024-11-07 21:38:41 -08:00
|
|
|
)
|
2024-12-03 17:00:44 -08:00
|
|
|
from autogen_core import Image
|
Memory Interface in AgentChat (#4438)
* initial base memroy impl
* update, add example with chromadb
* include mimetype consideration
* add transform method
* update to address feedback, will update after 4681 is merged
* update memory impl,
* remove chroma db, typing fixes
* format, add test
* update uv lock
* update docs
* format updates
* update notebook
* add memoryqueryevent message, yield message for observability.
* minor fixes, make score optional/none
* Update python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
* update tests to improve cov
* refactor, move memory to core.
* format fixxes
* format updates
* format updates
* fix azure notebook import, other fixes
* update notebook, support str query in Memory protocol
* update test
* update cells
* add specific extensible return types to memory query and update_context
---------
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
2025-01-13 23:06:13 -08:00
|
|
|
from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
|
2024-12-29 07:50:54 +01:00
|
|
|
from autogen_core.model_context import BufferedChatCompletionContext
|
|
|
|
from autogen_core.models import LLMMessage
|
2024-12-30 15:09:21 -05:00
|
|
|
from autogen_core.models._model_client import ModelFamily
|
2024-12-09 21:39:07 -05:00
|
|
|
from autogen_core.tools import FunctionTool
|
2024-12-10 13:18:09 +10:00
|
|
|
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
2024-10-22 13:27:06 -07:00
|
|
|
from openai.resources.chat.completions import AsyncCompletions
|
|
|
|
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
|
|
|
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
|
|
|
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
2024-12-15 11:18:17 +05:30
|
|
|
from openai.types.chat.chat_completion_message_tool_call import (
|
|
|
|
ChatCompletionMessageToolCall,
|
|
|
|
Function,
|
|
|
|
)
|
2024-10-22 13:27:06 -07:00
|
|
|
from openai.types.completion_usage import CompletionUsage
|
2024-12-03 14:45:10 -08:00
|
|
|
from utils import FileLogHandler
|
2024-10-22 13:27:06 -07:00
|
|
|
|
2024-10-29 08:04:14 -07:00
|
|
|
# Route agentchat runtime events (EVENT_LOGGER_NAME) to a file so failed
# test runs leave a full DEBUG trace in test_assistant_agent.log.
logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.setLevel(logging.DEBUG)
logger.addHandler(FileLogHandler("test_assistant_agent.log"))
|
|
|
|
|
2024-10-22 13:27:06 -07:00
|
|
|
|
|
|
|
class _MockChatCompletion:
    """Stand-in for the OpenAI ``AsyncCompletions.create`` endpoint.

    Replays a pre-recorded list of ``ChatCompletion`` responses in order
    and records each request's ``messages`` payload so tests can inspect
    exactly what was sent to the "model".
    """

    def __init__(self, chat_completions: List[ChatCompletion]) -> None:
        # Canned responses, replayed one per call in list order.
        self._saved_chat_completions = chat_completions
        # Index of the next response to replay; tests reset this to 0
        # to replay the same conversation (e.g. for run_stream).
        self.curr_index = 0
        # One entry per call: the `messages` kwarg the client was invoked with.
        self.calls: List[List[LLMMessage]] = []

    async def mock_create(
        self, *args: Any, **kwargs: Any
    ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
        """Monkeypatch target for ``AsyncCompletions.create``.

        Returns the next canned completion; raises IndexError if called
        more times than there are recorded responses.
        """
        self.calls.append(kwargs["messages"])  # Save the call
        await asyncio.sleep(0.1)  # Simulate a little network latency.
        completion = self._saved_chat_completions[self.curr_index]
        self.curr_index += 1
        return completion
|
|
|
|
|
|
|
|
|
|
|
|
def _pass_function(input: str) -> str:
|
|
|
|
return "pass"
|
|
|
|
|
|
|
|
|
|
|
|
async def _fail_function(input: str) -> str:
|
|
|
|
return "fail"
|
|
|
|
|
|
|
|
|
|
|
|
async def _echo_function(input: str) -> str:
|
|
|
|
return input
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
    """AssistantAgent executes a tool call and summarizes the result.

    Scripted model turns: (1) a tool_calls response invoking
    ``_pass_function``, (2) "pass", (3) "TERMINATE" — only the first is
    consumed per run since the agent makes a single model call here.
    Also verifies run_stream parity with run, and save/load state round-trip.
    """
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        # Turn 1: the model requests a call to _pass_function({"input": "task"}).
        ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="tool_calls",
                    index=0,
                    message=ChatCompletionMessage(
                        content=None,
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                id="1",
                                type="function",
                                function=Function(
                                    name="_pass_function",
                                    arguments=json.dumps({"input": "task"}),
                                ),
                            )
                        ],
                        role="assistant",
                    ),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=ChatCompletionMessage(content="pass", role="assistant"),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    agent = AssistantAgent(
        "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        tools=[
            _pass_function,
            _fail_function,
            FunctionTool(_echo_function, description="Echo"),
        ],
    )
    result = await agent.run(task="task")

    # Expected message sequence: user task, tool request, tool result, summary.
    assert len(result.messages) == 4
    assert isinstance(result.messages[0], TextMessage)
    assert result.messages[0].models_usage is None
    assert isinstance(result.messages[1], ToolCallRequestEvent)
    assert result.messages[1].models_usage is not None
    assert result.messages[1].models_usage.completion_tokens == 5
    assert result.messages[1].models_usage.prompt_tokens == 10
    assert isinstance(result.messages[2], ToolCallExecutionEvent)
    assert result.messages[2].models_usage is None
    assert isinstance(result.messages[3], ToolCallSummaryMessage)
    assert result.messages[3].content == "pass"
    assert result.messages[3].models_usage is None

    # Test streaming.
    mock.curr_index = 0  # Reset the mock
    index = 0
    async for message in agent.run_stream(task="task"):
        if isinstance(message, TaskResult):
            # The final TaskResult must equal the non-streaming result.
            assert message == result
        else:
            # Each streamed message must match the non-streaming run in order.
            assert message == result.messages[index]
            index += 1

    # Test state saving and loading.
    state = await agent.save_state()
    agent2 = AssistantAgent(
        "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
    )
    await agent2.load_state(state)
    state2 = await agent2.save_state()
    # A load/save round-trip must be lossless.
    assert state == state2
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) -> None:
    """With ``reflect_on_tool_use=True`` the final message is a model reflection.

    Same scripted tool-call turn as test_run_with_tools, but the second model
    call ("Hello") is consumed as the reflection, so messages[3] is a
    TextMessage with model usage rather than a ToolCallSummaryMessage.
    """
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        # Turn 1: tool call request for _pass_function.
        ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="tool_calls",
                    index=0,
                    message=ChatCompletionMessage(
                        content=None,
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                id="1",
                                type="function",
                                function=Function(
                                    name="_pass_function",
                                    arguments=json.dumps({"input": "task"}),
                                ),
                            )
                        ],
                        role="assistant",
                    ),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
        ),
        # Turn 2: reflection response produced after the tool result.
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(
                    finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    agent = AssistantAgent(
        "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
        reflect_on_tool_use=True,
    )
    result = await agent.run(task="task")

    # Expected sequence: user task, tool request, tool result, reflection text.
    assert len(result.messages) == 4
    assert isinstance(result.messages[0], TextMessage)
    assert result.messages[0].models_usage is None
    assert isinstance(result.messages[1], ToolCallRequestEvent)
    assert result.messages[1].models_usage is not None
    assert result.messages[1].models_usage.completion_tokens == 5
    assert result.messages[1].models_usage.prompt_tokens == 10
    assert isinstance(result.messages[2], ToolCallExecutionEvent)
    assert result.messages[2].models_usage is None
    assert isinstance(result.messages[3], TextMessage)
    assert result.messages[3].content == "Hello"
    # The reflection is a real model call, so it carries usage numbers.
    assert result.messages[3].models_usage is not None
    assert result.messages[3].models_usage.completion_tokens == 5
    assert result.messages[3].models_usage.prompt_tokens == 10

    # Test streaming.
    mock.curr_index = 0  # pyright: ignore
    index = 0
    async for message in agent.run_stream(task="task"):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
            index += 1

    # Test state saving and loading.
    state = await agent.save_state()
    agent2 = AssistantAgent(
        "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        tools=[
            _pass_function,
            _fail_function,
            FunctionTool(_echo_function, description="Echo"),
        ],
    )
    await agent2.load_state(state)
    state2 = await agent2.save_state()
    # A load/save round-trip must be lossless.
    assert state == state2
|
|
|
|
|
2024-10-29 08:04:14 -07:00
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
    """A tool call to the handoff's generated tool produces a HandoffMessage.

    The scripted model turn calls ``handoff.name`` (the synthetic handoff
    tool) with empty arguments; the agent must emit a HandoffMessage whose
    content/target come from the Handoff declaration.
    """
    handoff = Handoff(target="agent2")
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="tool_calls",
                    index=0,
                    message=ChatCompletionMessage(
                        content=None,
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                id="1",
                                type="function",
                                function=Function(
                                    name=handoff.name,
                                    arguments=json.dumps({}),
                                ),
                            )
                        ],
                        role="assistant",
                    ),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            # Distinct usage numbers (42/43) so the assertions below can
            # confirm usage is attributed to the right message.
            usage=CompletionUsage(prompt_tokens=42, completion_tokens=43, total_tokens=85),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    tool_use_agent = AssistantAgent(
        "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        tools=[
            _pass_function,
            _fail_function,
            FunctionTool(_echo_function, description="Echo"),
        ],
        handoffs=[handoff],
    )
    # Declaring a handoff must register HandoffMessage as a producible type.
    assert HandoffMessage in tool_use_agent.produced_message_types
    result = await tool_use_agent.run(task="task")
    assert len(result.messages) == 4
    assert isinstance(result.messages[0], TextMessage)
    assert result.messages[0].models_usage is None
    assert isinstance(result.messages[1], ToolCallRequestEvent)
    assert result.messages[1].models_usage is not None
    assert result.messages[1].models_usage.completion_tokens == 43
    assert result.messages[1].models_usage.prompt_tokens == 42
    assert isinstance(result.messages[2], ToolCallExecutionEvent)
    assert result.messages[2].models_usage is None
    assert isinstance(result.messages[3], HandoffMessage)
    assert result.messages[3].content == handoff.message
    assert result.messages[3].target == handoff.target
    assert result.messages[3].models_usage is None

    # Test streaming.
    mock.curr_index = 0  # pyright: ignore
    index = 0
    async for message in tool_use_agent.run_stream(task="task"):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
            index += 1
|
2024-11-07 21:38:41 -08:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
    """The agent accepts a MultiModalMessage (text + image) as the task.

    Only the message count is asserted: the task message plus one model
    response ("Hello").
    """
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id2",
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=ChatCompletionMessage(content="Hello", role="assistant"),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    agent = AssistantAgent(
        name="assistant",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
    )
    # Generate a random base64 image.
    img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
    result = await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
    # Task message + one model response.
    assert len(result.messages) == 2
|
2024-11-27 10:45:51 -08:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_invalid_model_capabilities() -> None:
    """A model lacking capabilities must be rejected with ValueError.

    With vision/function_calling/json_output all disabled in model_info:
    tools and handoffs fail at construction time, and a multi-modal task
    fails at run time.
    """
    model = "random-model"
    model_client = OpenAIChatCompletionClient(
        model=model,
        api_key="",
        # Advertise no capabilities at all.
        model_info={"vision": False, "function_calling": False, "json_output": False, "family": ModelFamily.UNKNOWN},
    )

    # Tools require function_calling.
    with pytest.raises(ValueError):
        agent = AssistantAgent(
            name="assistant",
            model_client=model_client,
            tools=[
                _pass_function,
                _fail_function,
                FunctionTool(_echo_function, description="Echo"),
            ],
        )

    # Handoffs are implemented via tool calls, so they too require function_calling.
    with pytest.raises(ValueError):
        agent = AssistantAgent(name="assistant", model_client=model_client, handoffs=["agent2"])

    # An image task requires vision; the failure surfaces at run() time.
    with pytest.raises(ValueError):
        agent = AssistantAgent(name="assistant", model_client=model_client)
        # Generate a random base64 image.
        img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
        await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
|
2024-12-15 11:18:17 +05:30
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
    """run()/run_stream() accept a list of ChatMessage as the task.

    Both input messages must appear in the result (in order, attributed to
    the user) followed by the single model response attributed to the agent.
    """
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=ChatCompletionMessage(content="Response to message 1", role="assistant"),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    agent = AssistantAgent(
        "test_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
    )

    # Create a list of chat messages
    messages: List[ChatMessage] = [
        TextMessage(content="Message 1", source="user"),
        TextMessage(content="Message 2", source="user"),
    ]

    # Test run method with list of messages
    result = await agent.run(task=messages)
    assert len(result.messages) == 3  # 2 input messages + 1 response message
    assert isinstance(result.messages[0], TextMessage)
    assert result.messages[0].content == "Message 1"
    assert result.messages[0].source == "user"
    assert isinstance(result.messages[1], TextMessage)
    assert result.messages[1].content == "Message 2"
    assert result.messages[1].source == "user"
    assert isinstance(result.messages[2], TextMessage)
    assert result.messages[2].content == "Response to message 1"
    assert result.messages[2].source == "test_agent"
    assert result.messages[2].models_usage is not None
    assert result.messages[2].models_usage.completion_tokens == 5
    assert result.messages[2].models_usage.prompt_tokens == 10

    # Test run_stream method with list of messages
    mock.curr_index = 0  # Reset mock index using public attribute
    index = 0
    async for message in agent.run_stream(task=messages):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
            index += 1
|
2024-12-29 07:50:54 +01:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_model_context(monkeypatch: pytest.MonkeyPatch) -> None:
    """A custom BufferedChatCompletionContext limits what reaches the model.

    With buffer_size=2 and three user messages, the single model call must
    receive only the last two context messages plus the system message.
    """
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=ChatCompletionMessage(content="Response to message 3", role="assistant"),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    model_context = BufferedChatCompletionContext(buffer_size=2)
    agent = AssistantAgent(
        "test_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        model_context=model_context,
    )

    messages = [
        TextMessage(content="Message 1", source="user"),
        TextMessage(content="Message 2", source="user"),
        TextMessage(content="Message 3", source="user"),
    ]
    await agent.run(task=messages)

    # Check if the mock client is called with only the last two messages.
    assert len(mock.calls) == 1
    # 2 message from the context + 1 system message
    assert len(mock.calls[0]) == 3
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None:
    """Exercise the ListMemory API and its integration with AssistantAgent.

    Covers: empty-context update, add/query of TEXT/JSON/IMAGE content,
    clear/close, rejection of a non-Memory argument, emission of a
    MemoryQueryEvent during a run, and the Memory runtime protocol check.
    """
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=ChatCompletionMessage(content="Hello", role="assistant"),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
        ),
    ]
    # 1x1 PNG used as the IMAGE memory payload.
    b64_image_str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

    # Test basic memory properties and empty context
    memory = ListMemory(name="test_memory")
    assert memory.name == "test_memory"

    empty_context = BufferedChatCompletionContext(buffer_size=2)
    empty_results = await memory.update_context(empty_context)
    assert len(empty_results.memories.results) == 0

    # Test various content types
    memory = ListMemory()
    await memory.add(MemoryContent(content="text content", mime_type=MemoryMimeType.TEXT))
    await memory.add(MemoryContent(content={"key": "value"}, mime_type=MemoryMimeType.JSON))
    await memory.add(MemoryContent(content=Image.from_base64(b64_image_str), mime_type=MemoryMimeType.IMAGE))

    # Test query functionality
    query_result = await memory.query(MemoryContent(content="", mime_type=MemoryMimeType.TEXT))
    assert isinstance(query_result, MemoryQueryResult)
    # Should have all three memories we added
    assert len(query_result.results) == 3

    # Test clear and cleanup
    await memory.clear()
    empty_query = await memory.query(MemoryContent(content="", mime_type=MemoryMimeType.TEXT))
    assert len(empty_query.results) == 0
    await memory.close()  # Should not raise

    # Test invalid memory type
    with pytest.raises(TypeError):
        AssistantAgent(
            "test_agent",
            model_client=OpenAIChatCompletionClient(model=model, api_key=""),
            memory="invalid",  # type: ignore
        )

    # Test with agent
    memory2 = ListMemory()
    await memory2.add(MemoryContent(content="test instruction", mime_type=MemoryMimeType.TEXT))

    agent = AssistantAgent(
        "test_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), memory=[memory2]
    )

    result = await agent.run(task="test task")
    assert len(result.messages) > 0
    # The memory lookup must surface as a MemoryQueryEvent in the run output.
    memory_event = next((msg for msg in result.messages if isinstance(msg, MemoryQueryEvent)), None)
    assert memory_event is not None
    assert len(memory_event.content) > 0
    assert isinstance(memory_event.content[0], MemoryContent)

    # Test memory protocol
    class BadMemory:
        pass

    assert not isinstance(BadMemory(), Memory)
    assert isinstance(ListMemory(), Memory)
|