mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-01 02:09:58 +00:00
AgentOps Runtime Logging Implementation (#2682)
* add agentops req * track conversable agents with agentops * track tool usage * track message sending * remove record from parent * remove record * simple example * notebook example * remove spacing change * optional dependency * documentation * remove extra import * optional import * record if agentops * if agentops * wrap function auto name * install agentops before notebook test * documentation fixes * notebook metadata * notebook metadata * pre-commit hook changes * doc link fixes * git lfs * autogen tag * bump agentops version * log tool events * notebook fixes * docs * formatting * Updated ecosystem manual * Update notebook for clarity * cleaned up notebook * updated precommit recommendations * Fixed links to screenshots and examples * removed unused files * changed notebook hyperlink * update docusaurus link path * reverted setup.py * change setup again * undo changes * revert conversable agent * removed file not in branch * Updated notebook to look nicer * change letter * revert setup * revert setup again * change ref link * change reflink * remove optional dependency * removed duplicated section * Addressed clarity commetns from howard * minor updates to wording * formatting and pr fixes * added info markdown cell * better docs * notebook * observability docs * pre-commit fixes * example images in notebook * example images in docs * example images in docs * delete agentops ong * doc updates * docs updates * docs updates * use agent as extra_kwarg * add logging tests * pass function properly * create table * dummy function name * log chat completion source name * safe serialize * test fixes * formatting * type checks --------- Co-authored-by: reibs <areibman@gmail.com> Co-authored-by: Chi Wang <wang.chi@microsoft.com> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> Co-authored-by: Howard Gil <howardbgil@gmail.com> Co-authored-by: Alex Reibman <meta.alex.r@gmail.com>
This commit is contained in:
parent
75f0808b5a
commit
85ad929f34
@ -31,7 +31,7 @@ from ..formatting_utils import colored
|
||||
from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
|
||||
from ..io.base import IOStream
|
||||
from ..oai.client import ModelClient, OpenAIWrapper
|
||||
from ..runtime_logging import log_event, log_new_agent, logging_enabled
|
||||
from ..runtime_logging import log_event, log_function_use, log_new_agent, logging_enabled
|
||||
from .agent import Agent, LLMAgent
|
||||
from .chat import ChatResult, a_initiate_chats, initiate_chats
|
||||
from .utils import consolidate_chat_info, gather_usage_summary
|
||||
@ -1357,9 +1357,7 @@ class ConversableAgent(LLMAgent):
|
||||
|
||||
# TODO: #1143 handle token limit exceeded error
|
||||
response = llm_client.create(
|
||||
context=messages[-1].pop("context", None),
|
||||
messages=all_messages,
|
||||
cache=cache,
|
||||
context=messages[-1].pop("context", None), messages=all_messages, cache=cache, agent=self
|
||||
)
|
||||
extracted_response = llm_client.extract_text_or_completion_object(response)[0]
|
||||
|
||||
@ -2528,13 +2526,14 @@ class ConversableAgent(LLMAgent):
|
||||
@functools.wraps(func)
def _wrapped_func(*args, **kwargs):
    # Synchronous wrapper: run the registered function, report the call to the
    # runtime logger, and hand back the result serialized to a string.
    retval = func(*args, **kwargs)

    # Report the call only while runtime logging is active. log_function_use()
    # writes an error-level message when no logger has been started, which
    # would otherwise be emitted on every single tool call.
    # NOTE: only keyword arguments are forwarded to the logger.
    if logging_enabled():
        log_function_use(self, func, kwargs, retval)

    return serialize_to_str(retval)
|
||||
|
||||
@load_basemodels_if_needed
@functools.wraps(func)
async def _a_wrapped_func(*args, **kwargs):
    # Asynchronous wrapper: await the registered coroutine function, report the
    # call to the runtime logger, and hand back the result serialized to a string.
    retval = await func(*args, **kwargs)

    # Report the call only while runtime logging is active. log_function_use()
    # writes an error-level message when no logger has been started, which
    # would otherwise be emitted on every single tool call.
    # NOTE: only keyword arguments are forwarded to the logger.
    if logging_enabled():
        log_function_use(self, func, kwargs, retval)

    return serialize_to_str(retval)
|
||||
|
||||
wrapped_func = _a_wrapped_func if inspect.iscoroutinefunction(func) else _wrapped_func
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import sqlite3
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, TypeVar, Union
|
||||
|
||||
from openai import AzureOpenAI, OpenAI
|
||||
from openai.types.chat import ChatCompletion
|
||||
@ -11,6 +11,7 @@ from openai.types.chat import ChatCompletion
|
||||
if TYPE_CHECKING:
|
||||
from autogen import Agent, ConversableAgent, OpenAIWrapper
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
ConfigItem = Dict[str, Union[str, List[str]]]
|
||||
LLMConfig = Dict[str, Union[None, float, int, ConfigItem, List[ConfigItem]]]
|
||||
|
||||
@ -32,6 +33,7 @@ class BaseLogger(ABC):
|
||||
invocation_id: uuid.UUID,
|
||||
client_id: int,
|
||||
wrapper_id: int,
|
||||
source: Union[str, Agent],
|
||||
request: Dict[str, Union[float, str, List[Dict[str, str]]]],
|
||||
response: Union[str, ChatCompletion],
|
||||
is_cached: int,
|
||||
@ -49,9 +51,10 @@ class BaseLogger(ABC):
|
||||
invocation_id (uuid): A unique identifier for the invocation to the OpenAIWrapper.create method call
|
||||
client_id (int): A unique identifier for the underlying OpenAI client instance
|
||||
wrapper_id (int): A unique identifier for the OpenAIWrapper instance
|
||||
request (dict): A dictionary representing the the request or call to the OpenAI client endpoint
|
||||
source (str or Agent): The source/creator of the event as a string name or an Agent instance
|
||||
request (dict): A dictionary representing the request or call to the OpenAI client endpoint
|
||||
response (str or ChatCompletion): The response from OpenAI
|
||||
is_chached (int): 1 if the response was a cache hit, 0 otherwise
|
||||
is_cached (int): 1 if the response was a cache hit, 0 otherwise
|
||||
cost(float): The cost for OpenAI response
|
||||
start_time (str): A string representing the moment the request was initiated
|
||||
"""
|
||||
@ -104,6 +107,18 @@ class BaseLogger(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[str, Any], returns: Any) -> None:
    """
    Record a single invocation of a registered function (which may be a tool).

    Args:
        source (str or Agent): The source/creator of the event as a string name or an Agent instance
        function (F): The callable that was invoked
        args (dict): The arguments the function was called with
        returns (any): The value the function returned
    """
    ...
|
||||
|
||||
@abstractmethod
|
||||
def stop(self) -> None:
|
||||
"""
|
||||
|
||||
@ -5,7 +5,7 @@ import logging
|
||||
import os
|
||||
import threading
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
|
||||
|
||||
from openai import AzureOpenAI, OpenAI
|
||||
from openai.types.chat import ChatCompletion
|
||||
@ -21,9 +21,21 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
__all__ = ("FileLogger",)
|
||||
|
||||
|
||||
def safe_serialize(obj: Any) -> str:
    """Serialize `obj` to a JSON string without raising.

    Values the `json` module cannot encode natively are substituted:
    objects exposing a `to_json()` method contribute `str(o.to_json())`,
    anything else becomes a `<<non-serializable: TypeName>>` placeholder.

    If encoding still fails as a whole, the entire value is replaced by a
    JSON-encoded placeholder naming the type of `obj`, so a logging call can
    never fail because of the payload it was asked to record.

    Args:
        obj: The value to serialize.

    Returns:
        A JSON document as a string.
    """

    def default(o: Any) -> str:
        if hasattr(o, "to_json"):
            return str(o.to_json())
        return f"<<non-serializable: {type(o).__qualname__}>>"

    try:
        return json.dumps(obj, default=default)
    except Exception:
        # json.dumps still raises for dict keys of unsupported types (TypeError),
        # circular references (ValueError), and whatever a custom to_json() throws.
        # This helper sits on the logging boundary, so degrade to a placeholder.
        return json.dumps(f"<<non-serializable: {type(obj).__qualname__}>>")
|
||||
|
||||
|
||||
class FileLogger(BaseLogger):
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
@ -59,6 +71,7 @@ class FileLogger(BaseLogger):
|
||||
invocation_id: uuid.UUID,
|
||||
client_id: int,
|
||||
wrapper_id: int,
|
||||
source: Union[str, Agent],
|
||||
request: Dict[str, Union[float, str, List[Dict[str, str]]]],
|
||||
response: Union[str, ChatCompletion],
|
||||
is_cached: int,
|
||||
@ -69,6 +82,11 @@ class FileLogger(BaseLogger):
|
||||
Log a chat completion.
|
||||
"""
|
||||
thread_id = threading.get_ident()
|
||||
source_name = None
|
||||
if isinstance(source, str):
|
||||
source_name = source
|
||||
else:
|
||||
source_name = source.name
|
||||
try:
|
||||
log_data = json.dumps(
|
||||
{
|
||||
@ -82,6 +100,7 @@ class FileLogger(BaseLogger):
|
||||
"start_time": start_time,
|
||||
"end_time": get_current_ts(),
|
||||
"thread_id": thread_id,
|
||||
"source_name": source_name,
|
||||
}
|
||||
)
|
||||
|
||||
@ -204,6 +223,29 @@ class FileLogger(BaseLogger):
|
||||
except Exception as e:
|
||||
self.logger.error(f"[file_logger] Failed to log event {e}")
|
||||
|
||||
def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[str, Any], returns: Any) -> None:
    """
    Log a registered function (can be a tool) use from an agent or a string source.

    The record is written as one JSON object through `self.logger.info`. Any
    failure while building or writing it is reported through `self.logger.error`
    instead of being raised.

    Args:
        source (str or Agent): The source/creator of the event as a string name or an Agent instance
        function (F): The function that was called; its name is stored as `function_name`
        args (dict): The function args to log
        returns (any): The value returned by the function
    """
    thread_id = threading.get_ident()

    try:
        log_data = json.dumps(
            {
                "source_id": id(source),
                "source_name": str(source.name) if hasattr(source, "name") else source,
                # Read the module from the type: instances of builtin types such as
                # `str` have no `__module__` attribute, so `source.__module__` raised
                # AttributeError for string sources and the whole record was lost.
                "agent_module": type(source).__module__,
                "agent_class": source.__class__.__name__,
                # Callables such as functools.partial objects carry no __name__.
                "function_name": getattr(function, "__name__", repr(function)),
                "timestamp": get_current_ts(),
                "thread_id": thread_id,
                "input_args": safe_serialize(args),
                "returns": safe_serialize(returns),
            }
        )
        self.logger.info(log_data)
    except Exception as e:
        self.logger.error(f"[file_logger] Failed to log event {e}")
|
||||
|
||||
def get_connection(self) -> None:
|
||||
"""Method is intentionally left blank because there is no specific connection needed for the FileLogger."""
|
||||
pass
|
||||
|
||||
@ -6,7 +6,7 @@ import os
|
||||
import sqlite3
|
||||
import threading
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, TypeVar, Union
|
||||
|
||||
from openai import AzureOpenAI, OpenAI
|
||||
from openai.types.chat import ChatCompletion
|
||||
@ -25,6 +25,18 @@ lock = threading.Lock()
|
||||
|
||||
__all__ = ("SqliteLogger",)
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def safe_serialize(obj: Any) -> str:
    """Serialize `obj` to a JSON string without raising.

    Values the `json` module cannot encode natively are substituted:
    objects exposing a `to_json()` method contribute `str(o.to_json())`,
    anything else becomes a `<<non-serializable: TypeName>>` placeholder.

    If encoding still fails as a whole, the entire value is replaced by a
    JSON-encoded placeholder naming the type of `obj`, so a logging call can
    never fail because of the payload it was asked to record.

    Args:
        obj: The value to serialize.

    Returns:
        A JSON document as a string.
    """

    def default(o: Any) -> str:
        if hasattr(o, "to_json"):
            return str(o.to_json())
        return f"<<non-serializable: {type(o).__qualname__}>>"

    try:
        return json.dumps(obj, default=default)
    except Exception:
        # json.dumps still raises for dict keys of unsupported types (TypeError),
        # circular references (ValueError), and whatever a custom to_json() throws.
        # This helper sits on the logging boundary, so degrade to a placeholder.
        return json.dumps(f"<<non-serializable: {type(obj).__qualname__}>>")
|
||||
|
||||
|
||||
class SqliteLogger(BaseLogger):
|
||||
schema_version = 1
|
||||
@ -49,6 +61,7 @@ class SqliteLogger(BaseLogger):
|
||||
client_id INTEGER,
|
||||
wrapper_id INTEGER,
|
||||
session_id TEXT,
|
||||
source_name TEXT,
|
||||
request TEXT,
|
||||
response TEXT,
|
||||
is_cached INEGER,
|
||||
@ -118,6 +131,18 @@ class SqliteLogger(BaseLogger):
|
||||
"""
|
||||
self._run_query(query=query)
|
||||
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS function_calls (
|
||||
source_id INTEGER,
|
||||
source_name TEXT,
|
||||
function_name TEXT,
|
||||
args TEXT DEFAULT NULL,
|
||||
returns TEXT DEFAULT NULL,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
"""
|
||||
self._run_query(query=query)
|
||||
|
||||
current_verion = self._get_current_db_version()
|
||||
if current_verion is None:
|
||||
self._run_query(
|
||||
@ -192,6 +217,7 @@ class SqliteLogger(BaseLogger):
|
||||
invocation_id: uuid.UUID,
|
||||
client_id: int,
|
||||
wrapper_id: int,
|
||||
source: Union[str, Agent],
|
||||
request: Dict[str, Union[float, str, List[Dict[str, str]]]],
|
||||
response: Union[str, ChatCompletion],
|
||||
is_cached: int,
|
||||
@ -208,10 +234,16 @@ class SqliteLogger(BaseLogger):
|
||||
else:
|
||||
response_messages = json.dumps(to_dict(response), indent=4)
|
||||
|
||||
source_name = None
|
||||
if isinstance(source, str):
|
||||
source_name = source
|
||||
else:
|
||||
source_name = source.name
|
||||
|
||||
query = """
|
||||
INSERT INTO chat_completions (
|
||||
invocation_id, client_id, wrapper_id, session_id, request, response, is_cached, cost, start_time, end_time
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
invocation_id, client_id, wrapper_id, session_id, request, response, is_cached, cost, start_time, end_time, source_name
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
args = (
|
||||
invocation_id,
|
||||
@ -224,6 +256,7 @@ class SqliteLogger(BaseLogger):
|
||||
cost,
|
||||
start_time,
|
||||
end_time,
|
||||
source_name,
|
||||
)
|
||||
|
||||
self._run_query(query=query, args=args)
|
||||
@ -335,6 +368,24 @@ class SqliteLogger(BaseLogger):
|
||||
)
|
||||
self._run_query(query=query, args=args)
|
||||
|
||||
def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[str, Any], returns: Any) -> None:
    """
    Insert one row describing a registered-function call into the `function_calls` table.

    Nothing is written when `self.con` is None.

    Args:
        source (str or Agent): The source/creator of the event as a string name or an Agent instance
        function (F): The function that was called; its `__name__` is stored
        args (dict): The function args, stored in serialized form
        returns (any): The function's return value, stored in serialized form
    """
    if self.con is None:
        return

    # Agents are identified by their `name`; a plain string source is stored as-is.
    source_name = source.name if hasattr(source, "name") else source
    row: Tuple[Any, ...] = (
        id(source),
        source_name,
        function.__name__,
        safe_serialize(args),
        safe_serialize(returns),
        get_current_ts(),
    )

    insert_sql = """
    INSERT INTO function_calls (source_id, source_name, function_name, args, returns, timestamp) VALUES (?, ?, ?, ?, ?, ?)
    """
    self._run_query(query=insert_sql, args=row)
|
||||
|
||||
def log_new_client(
|
||||
self, client: Union[AzureOpenAI, OpenAI, GeminiClient], wrapper: OpenAIWrapper, init_args: Dict[str, Any]
|
||||
) -> None:
|
||||
|
||||
@ -319,6 +319,7 @@ class OpenAIWrapper:
|
||||
"""A wrapper class for openai client."""
|
||||
|
||||
extra_kwargs = {
|
||||
"agent",
|
||||
"cache",
|
||||
"cache_seed",
|
||||
"filter_func",
|
||||
@ -542,6 +543,7 @@ class OpenAIWrapper:
|
||||
Note that the cache argument overrides the legacy cache_seed argument: if this argument is provided,
|
||||
then the cache_seed argument is ignored. If this argument is not provided or None,
|
||||
then the cache_seed argument is used.
|
||||
- agent (AbstractAgent | None): The object responsible for creating a completion if an agent.
|
||||
- (Legacy) cache_seed (int | None) for using the DiskCache. Default to 41.
|
||||
An integer cache_seed is useful when implementing "controlled randomness" for the completion.
|
||||
None for no caching.
|
||||
@ -589,6 +591,7 @@ class OpenAIWrapper:
|
||||
cache = extra_kwargs.get("cache")
|
||||
filter_func = extra_kwargs.get("filter_func")
|
||||
context = extra_kwargs.get("context")
|
||||
agent = extra_kwargs.get("agent")
|
||||
|
||||
total_usage = None
|
||||
actual_usage = None
|
||||
@ -626,6 +629,7 @@ class OpenAIWrapper:
|
||||
invocation_id=invocation_id,
|
||||
client_id=id(client),
|
||||
wrapper_id=id(self),
|
||||
agent=agent,
|
||||
request=params,
|
||||
response=response,
|
||||
is_cached=1,
|
||||
@ -658,6 +662,7 @@ class OpenAIWrapper:
|
||||
invocation_id=invocation_id,
|
||||
client_id=id(client),
|
||||
wrapper_id=id(self),
|
||||
agent=agent,
|
||||
request=params,
|
||||
response=f"error_code:{error_code}, config {i} failed",
|
||||
is_cached=0,
|
||||
@ -688,6 +693,7 @@ class OpenAIWrapper:
|
||||
invocation_id=invocation_id,
|
||||
client_id=id(client),
|
||||
wrapper_id=id(self),
|
||||
agent=agent,
|
||||
request=params,
|
||||
response=response,
|
||||
is_cached=0,
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
import sqlite3
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, TypeVar, Union
|
||||
|
||||
from openai import AzureOpenAI, OpenAI
|
||||
from openai.types.chat import ChatCompletion
|
||||
@ -20,6 +20,8 @@ logger = logging.getLogger(__name__)
|
||||
autogen_logger = None
|
||||
is_logging = False
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def start(
|
||||
logger: Optional[BaseLogger] = None,
|
||||
@ -56,6 +58,7 @@ def log_chat_completion(
|
||||
invocation_id: uuid.UUID,
|
||||
client_id: int,
|
||||
wrapper_id: int,
|
||||
agent: Union[str, Agent],
|
||||
request: Dict[str, Union[float, str, List[Dict[str, str]]]],
|
||||
response: Union[str, ChatCompletion],
|
||||
is_cached: int,
|
||||
@ -67,7 +70,7 @@ def log_chat_completion(
|
||||
return
|
||||
|
||||
autogen_logger.log_chat_completion(
|
||||
invocation_id, client_id, wrapper_id, request, response, is_cached, cost, start_time
|
||||
invocation_id, client_id, wrapper_id, agent, request, response, is_cached, cost, start_time
|
||||
)
|
||||
|
||||
|
||||
@ -87,6 +90,14 @@ def log_event(source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) ->
|
||||
autogen_logger.log_event(source, name, **kwargs)
|
||||
|
||||
|
||||
def log_function_use(agent: Union[str, Agent], function: F, args: Dict[str, Any], returns: Any) -> None:
    """
    Forward the use of a registered function (could be a tool) to the active logger.

    When no logger has been set (`autogen_logger` is None) an error is logged
    and nothing is recorded.

    Args:
        agent (str or Agent): The source/creator of the call as a string name or an Agent instance
        function (F): The function that was called
        args (dict): The function args to log
        returns (any): The value returned by the function
    """
    if autogen_logger is None:
        logger.error("[runtime logging] log_function_use: autogen logger is None")
        return

    autogen_logger.log_function_use(agent, function, args, returns)
|
||||
|
||||
|
||||
def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
|
||||
if autogen_logger is None:
|
||||
logger.error("[runtime logging] log_new_wrapper: autogen logger is None")
|
||||
|
||||
533
notebook/agentchat_agentops.ipynb
Normal file
533
notebook/agentchat_agentops.ipynb
Normal file
@ -0,0 +1,533 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "abb8a01d85d8b146",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": [
|
||||
"# AgentOps"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a447802c88c8a240",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"[AgentOps](https://agentops.ai/?=autogen) provides session replays, metrics, and monitoring for AI agents.\n",
|
||||
"\n",
|
||||
"At a high level, AgentOps gives you the ability to monitor LLM calls, costs, latency, agent failures, multi-agent interactions, tool usage, session-wide statistics, and more. For more info, check out the [AgentOps Repo](https://github.com/AgentOps-AI/agentops).\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b354c068",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Dashboard\n",
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "38182a5296dceb34",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Adding AgentOps to an existing Autogen service.\n",
|
||||
"To get started, you'll need to install the AgentOps package and set an API key.\n",
|
||||
"\n",
|
||||
"AgentOps automatically configures itself when it's initialized. This means your agents will be tracked and logged to your AgentOps account right away."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8d9451f4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"````{=mdx}\n",
|
||||
":::info Requirements\n",
|
||||
"Some extra dependencies are needed for this notebook, which can be installed via pip:\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"pip install pyautogen agentops\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"For more information, please refer to the [installation guide](/docs/installation/).\n",
|
||||
":::\n",
|
||||
"````"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6be9e11620b0e8d6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Set an API key\n",
|
||||
"\n",
|
||||
"By default, the AgentOps `init()` function will look for an environment variable named `AGENTOPS_API_KEY`. Alternatively, you can pass one in as an optional parameter.\n",
|
||||
"\n",
|
||||
"Create an account and API key at [AgentOps.ai](https://agentops.ai/)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "f31a28d20a13b377",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-31T22:48:27.679318Z",
|
||||
"start_time": "2024-05-31T22:48:26.192071Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"🖇 AgentOps: \u001B[34m\u001B[34mSession Replay: https://app.agentops.ai/drilldown?session_id=8bfaeed1-fd51-4c68-b3ec-276b1a3ce8a4\u001B[0m\u001B[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "UUID('8bfaeed1-fd51-4c68-b3ec-276b1a3ce8a4')"
|
||||
},
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import agentops\n",
|
||||
"\n",
|
||||
"from autogen import ConversableAgent, UserProxyAgent, config_list_from_json\n",
|
||||
"\n",
|
||||
"agentops.init(api_key=\"7c94212b-b89d-47a6-a20c-23b2077d3226\") # or agentops.init(api_key=\"...\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4dd8f461ccd9cbef",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Autogen will now start automatically tracking\n",
|
||||
"- LLM prompts and completions\n",
|
||||
"- Token usage and costs\n",
|
||||
"- Agent names and actions\n",
|
||||
"- Correspondence between agents\n",
|
||||
"- Tool usage\n",
|
||||
"- Errors"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "712315c520536eb8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Simple Chat Example"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "66d68e66e9f4a677",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-31T22:48:32.813123Z",
|
||||
"start_time": "2024-05-31T22:48:27.677564Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001B[33magent\u001B[0m (to user):\n",
|
||||
"\n",
|
||||
"How can I help you today?\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[33muser\u001B[0m (to agent):\n",
|
||||
"\n",
|
||||
"2+2\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33magent\u001B[0m (to user):\n",
|
||||
"\n",
|
||||
"2 + 2 equals 4.\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"🖇 AgentOps: This run's cost $0.000960\n",
|
||||
"🖇 AgentOps: \u001B[34m\u001B[34mSession Replay: https://app.agentops.ai/drilldown?session_id=8bfaeed1-fd51-4c68-b3ec-276b1a3ce8a4\u001B[0m\u001B[0m\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import agentops\n",
|
||||
"\n",
|
||||
"# When initializing AgentOps, you can pass in optional tags to help filter sessions\n",
|
||||
"agentops.init(tags=[\"simple-autogen-example\"])\n",
|
||||
"\n",
|
||||
"# Create the agent that uses the LLM.\n",
|
||||
"config_list = config_list_from_json(env_or_file=\"OAI_CONFIG_LIST\")\n",
|
||||
"assistant = ConversableAgent(\"agent\", llm_config={\"config_list\": config_list})\n",
|
||||
"\n",
|
||||
"# Create the agent that represents the user in the conversation.\n",
|
||||
"user_proxy = UserProxyAgent(\"user\", code_execution_config=False)\n",
|
||||
"\n",
|
||||
"# Let the assistant start the conversation. It will end when the user types exit.\n",
|
||||
"assistant.initiate_chat(user_proxy, message=\"How can I help you today?\")\n",
|
||||
"\n",
|
||||
"# Close your AgentOps session to indicate that it completed.\n",
|
||||
"agentops.end_session(\"Success\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2217ed0f930cfcaa",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can view data on this run at [app.agentops.ai](https://app.agentops.ai). \n",
|
||||
"\n",
|
||||
"The dashboard will display LLM events for each message sent by each agent, including those made by the human user."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
""
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"id": "cbd689b0f5617013"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fd78f1a816276cb7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Tool Example\n",
|
||||
"AgentOps also tracks when Autogen agents use tools. You can find more information on this example in [tool-use.ipynb](https://github.com/microsoft/autogen/blob/main/website/docs/tutorial/tool-use.ipynb)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "3498aa6176c799ff",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-05-31T22:48:35.808674Z",
|
||||
"start_time": "2024-05-31T22:48:32.813225Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"🖇 AgentOps: \u001B[34m\u001B[34mSession Replay: https://app.agentops.ai/drilldown?session_id=880c206b-751e-4c23-9313-8684537fc04d\u001B[0m\u001B[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"What is (1423 - 123) / 3 + (32 + 23) * 5?\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"\u001B[32m***** Suggested tool call (call_aINcGyo0Xkrh9g7buRuhyCz0): calculator *****\u001B[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\n",
|
||||
" \"a\": 1423,\n",
|
||||
" \"b\": 123,\n",
|
||||
" \"operator\": \"-\"\n",
|
||||
"}\n",
|
||||
"\u001B[32m***************************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001B[0m\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001B[32m***** Response from calling tool (call_aINcGyo0Xkrh9g7buRuhyCz0) *****\u001B[0m\n",
|
||||
"1300\n",
|
||||
"\u001B[32m**********************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"\u001B[32m***** Suggested tool call (call_prJGf8V0QVT7cbD91e0Fcxpb): calculator *****\u001B[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\n",
|
||||
" \"a\": 1300,\n",
|
||||
" \"b\": 3,\n",
|
||||
" \"operator\": \"/\"\n",
|
||||
"}\n",
|
||||
"\u001B[32m***************************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001B[0m\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001B[32m***** Response from calling tool (call_prJGf8V0QVT7cbD91e0Fcxpb) *****\u001B[0m\n",
|
||||
"433\n",
|
||||
"\u001B[32m**********************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/braelynboynton/Developer/agentops/autogen/autogen/agentchat/conversable_agent.py:2489: UserWarning: Function 'calculator' is being overridden.\n",
|
||||
" warnings.warn(f\"Function '{tool_sig['function']['name']}' is being overridden.\", UserWarning)\n",
|
||||
"/Users/braelynboynton/Developer/agentops/autogen/autogen/agentchat/conversable_agent.py:2408: UserWarning: Function 'calculator' is being overridden.\n",
|
||||
" warnings.warn(f\"Function '{name}' is being overridden.\", UserWarning)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"\u001B[32m***** Suggested tool call (call_CUIgHRsySLjayDKuUphI1TGm): calculator *****\u001B[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\n",
|
||||
" \"a\": 32,\n",
|
||||
" \"b\": 23,\n",
|
||||
" \"operator\": \"+\"\n",
|
||||
"}\n",
|
||||
"\u001B[32m***************************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001B[0m\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001B[32m***** Response from calling tool (call_CUIgHRsySLjayDKuUphI1TGm) *****\u001B[0m\n",
|
||||
"55\n",
|
||||
"\u001B[32m**********************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"\u001B[32m***** Suggested tool call (call_L7pGtBLUf9V0MPL90BASyesr): calculator *****\u001B[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\n",
|
||||
" \"a\": 55,\n",
|
||||
" \"b\": 5,\n",
|
||||
" \"operator\": \"*\"\n",
|
||||
"}\n",
|
||||
"\u001B[32m***************************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001B[0m\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001B[32m***** Response from calling tool (call_L7pGtBLUf9V0MPL90BASyesr) *****\u001B[0m\n",
|
||||
"275\n",
|
||||
"\u001B[32m**********************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"\u001B[32m***** Suggested tool call (call_Ygo6p4XfcxRjkYBflhG3UVv6): calculator *****\u001B[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\n",
|
||||
" \"a\": 433,\n",
|
||||
" \"b\": 275,\n",
|
||||
" \"operator\": \"+\"\n",
|
||||
"}\n",
|
||||
"\u001B[32m***************************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001B[0m\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001B[32m***** Response from calling tool (call_Ygo6p4XfcxRjkYBflhG3UVv6) *****\u001B[0m\n",
|
||||
"708\n",
|
||||
"\u001B[32m**********************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"The result of the calculation is 708.\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"TERMINATE\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"🖇 AgentOps: This run's cost $0.001800\n",
|
||||
"🖇 AgentOps: \u001B[34m\u001B[34mSession Replay: https://app.agentops.ai/drilldown?session_id=880c206b-751e-4c23-9313-8684537fc04d\u001B[0m\u001B[0m\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from typing import Annotated, Literal\n",
|
||||
"\n",
|
||||
"from autogen import ConversableAgent, config_list_from_json, register_function\n",
|
||||
"\n",
|
||||
"agentops.start_session(tags=[\"autogen-tool-example\"])\n",
|
||||
"\n",
|
||||
"Operator = Literal[\"+\", \"-\", \"*\", \"/\"]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def calculator(a: int, b: int, operator: Annotated[Operator, \"operator\"]) -> int:\n",
|
||||
" if operator == \"+\":\n",
|
||||
" return a + b\n",
|
||||
" elif operator == \"-\":\n",
|
||||
" return a - b\n",
|
||||
" elif operator == \"*\":\n",
|
||||
" return a * b\n",
|
||||
" elif operator == \"/\":\n",
|
||||
" return int(a / b)\n",
|
||||
" else:\n",
|
||||
" raise ValueError(\"Invalid operator\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"config_list = config_list_from_json(env_or_file=\"OAI_CONFIG_LIST\")\n",
|
||||
"\n",
|
||||
"# Create the agent that uses the LLM.\n",
|
||||
"assistant = ConversableAgent(\n",
|
||||
" name=\"Assistant\",\n",
|
||||
" system_message=\"You are a helpful AI assistant. \"\n",
|
||||
" \"You can help with simple calculations. \"\n",
|
||||
" \"Return 'TERMINATE' when the task is done.\",\n",
|
||||
" llm_config={\"config_list\": config_list},\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# The user proxy agent is used for interacting with the assistant agent\n",
|
||||
"# and executes tool calls.\n",
|
||||
"user_proxy = ConversableAgent(\n",
|
||||
" name=\"User\",\n",
|
||||
" llm_config=False,\n",
|
||||
" is_termination_msg=lambda msg: msg.get(\"content\") is not None and \"TERMINATE\" in msg[\"content\"],\n",
|
||||
" human_input_mode=\"NEVER\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"assistant.register_for_llm(name=\"calculator\", description=\"A simple calculator\")(calculator)\n",
|
||||
"user_proxy.register_for_execution(name=\"calculator\")(calculator)\n",
|
||||
"\n",
|
||||
"# Register the calculator function to the two agents.\n",
|
||||
"register_function(\n",
|
||||
" calculator,\n",
|
||||
" caller=assistant, # The assistant agent can suggest calls to the calculator.\n",
|
||||
" executor=user_proxy, # The user proxy agent can execute the calculator calls.\n",
|
||||
" name=\"calculator\", # By default, the function name is used as the tool name.\n",
|
||||
" description=\"A simple calculator\", # A description of the tool.\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Let the assistant start the conversation. It will end when the assistant replies with TERMINATE.\n",
|
||||
"user_proxy.initiate_chat(assistant, message=\"What is (1423 - 123) / 3 + (32 + 23) * 5?\")\n",
|
||||
"\n",
|
||||
"agentops.end_session(\"Success\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2b4edf8e70d17267",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can see your run in action at [app.agentops.ai](https://app.agentops.ai). In this example, the AgentOps dashboard will show:\n",
|
||||
"- Agents talking to each other\n",
|
||||
"- Each use of the `calculator` tool\n",
|
||||
"- Each call to OpenAI for LLM use"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
""
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"id": "a922a52ab5fce31"
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"front_matter": {
|
||||
"description": "Use AgentOps to simplify the development process and monitor your agents in production.",
|
||||
"tags": [
|
||||
"monitoring",
|
||||
"debugging"
|
||||
]
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@ -40,7 +40,7 @@
|
||||
" filter_dict={\"tags\": [\"gpt-4\"]}, # comment out to get all\n",
|
||||
")\n",
|
||||
"# When using a single openai endpoint, you can use the following:\n",
|
||||
"# config_list = [{\"model\": \"gpt-4\", \"api_key\": os.getenv(\"OPENAI_API_KEY\")}]\n"
|
||||
"# config_list = [{\"model\": \"gpt-4\", \"api_key\": os.getenv(\"OPENAI_API_KEY\")}]"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -25,12 +25,12 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Logging session ID: 6e08f3e0-392b-434e-8b69-4ab36c4fcf99\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
|
||||
"\u001B[33muser_proxy\u001B[0m (to assistant):\n",
|
||||
"\n",
|
||||
"What is the height of the Eiffel Tower? Only respond with the answer and terminate\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
|
||||
"\u001B[33massistant\u001B[0m (to user_proxy):\n",
|
||||
"\n",
|
||||
"The height of the Eiffel Tower is approximately 330 meters.\n",
|
||||
"\n",
|
||||
@ -313,12 +313,12 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Logging session ID: ed493ebf-d78e-49f0-b832-69557276d557\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
|
||||
"\u001B[33muser_proxy\u001B[0m (to assistant):\n",
|
||||
"\n",
|
||||
"What is the height of the Eiffel Tower? Only respond with the answer and terminate\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
|
||||
"\u001B[33massistant\u001B[0m (to user_proxy):\n",
|
||||
"\n",
|
||||
"The height of the Eiffel Tower is 330 meters.\n",
|
||||
"TERMINATE\n",
|
||||
@ -328,7 +328,6 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"import autogen\n",
|
||||
|
||||
@ -3,6 +3,7 @@ import os
|
||||
import sys
|
||||
import tempfile
|
||||
import uuid
|
||||
from typing import Any, Callable
|
||||
|
||||
import pytest
|
||||
|
||||
@ -19,6 +20,10 @@ from autogen.logger.file_logger import FileLogger
|
||||
is_windows = sys.platform.startswith("win")
|
||||
|
||||
|
||||
def dummy_function(param1: str, param2: int) -> Any:
    """Return ``param1`` multiplied by ``param2``.

    Serves as a simple stand-in callable that the logging tests pass to
    ``log_function_use``.
    """
    product = param1 * param2
    return product
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_windows, reason="Skipping file logging tests on Windows")
|
||||
@pytest.fixture
|
||||
def logger() -> FileLogger:
|
||||
@ -49,8 +54,19 @@ def test_log_chat_completion(logger: FileLogger):
|
||||
is_cached = 0
|
||||
cost = 0.5
|
||||
start_time = "2024-05-06 15:20:21.263231"
|
||||
agent = autogen.AssistantAgent(name="TestAgent", code_execution_config=False)
|
||||
|
||||
logger.log_chat_completion(invocation_id, client_id, wrapper_id, request, response, is_cached, cost, start_time)
|
||||
logger.log_chat_completion(
|
||||
invocation_id=invocation_id,
|
||||
client_id=client_id,
|
||||
wrapper_id=wrapper_id,
|
||||
request=request,
|
||||
response=response,
|
||||
is_cached=is_cached,
|
||||
cost=cost,
|
||||
start_time=start_time,
|
||||
source=agent,
|
||||
)
|
||||
|
||||
with open(logger.log_file, "r") as f:
|
||||
lines = f.readlines()
|
||||
@ -63,6 +79,26 @@ def test_log_chat_completion(logger: FileLogger):
|
||||
assert log_data["is_cached"] == is_cached
|
||||
assert log_data["cost"] == cost
|
||||
assert log_data["start_time"] == start_time
|
||||
assert log_data["source_name"] == "TestAgent"
|
||||
assert isinstance(log_data["thread_id"], int)
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_windows, reason="Skipping file logging tests on Windows")
def test_log_function_use(logger: FileLogger):
    """Check that ``log_function_use`` writes a single JSON record for the call."""
    agent = autogen.AssistantAgent(name="TestAgent", code_execution_config=False)
    call_args = {"foo": "bar"}
    call_result = True
    target: Callable[[str, int], Any] = dummy_function

    logger.log_function_use(source=agent, function=target, args=call_args, returns=call_result)

    with open(logger.log_file, "r") as log_file:
        records = log_file.readlines()

    # Exactly one event was logged, so exactly one line is expected.
    assert len(records) == 1

    entry = json.loads(records[0])
    assert entry["source_name"] == "TestAgent"
    # Arguments and return value are compared against their JSON-encoded form.
    assert entry["input_args"] == json.dumps(call_args)
    assert entry["returns"] == json.dumps(call_result)
    assert isinstance(entry["thread_id"], int)
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import sqlite3
|
||||
import uuid
|
||||
from typing import Any, Callable
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
@ -61,6 +62,11 @@ SAMPLE_CHAT_RESPONSE = json.loads(
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def dummy_function(param1: str, param2: int) -> Any:
    """Return ``param1`` multiplied by ``param2``.

    Serves as a simple stand-in callable that the logging tests pass to
    ``log_function_use``.
    """
    product = param1 * param2
    return product
|
||||
|
||||
|
||||
###############################################################
|
||||
|
||||
|
||||
@ -84,6 +90,7 @@ def get_sample_chat_completion(response):
|
||||
"is_cached": 0,
|
||||
"cost": 0.347,
|
||||
"start_time": get_current_ts(),
|
||||
"agent": autogen.AssistantAgent(name="TestAgent", code_execution_config=False),
|
||||
}
|
||||
|
||||
|
||||
@ -103,7 +110,7 @@ def test_log_completion(response, expected_logged_response, db_connection):
|
||||
|
||||
query = """
|
||||
SELECT invocation_id, client_id, wrapper_id, request, response, is_cached,
|
||||
cost, start_time FROM chat_completions
|
||||
cost, start_time, source_name FROM chat_completions
|
||||
"""
|
||||
|
||||
for row in cur.execute(query):
|
||||
@ -115,6 +122,28 @@ def test_log_completion(response, expected_logged_response, db_connection):
|
||||
assert row["is_cached"] == sample_completion["is_cached"]
|
||||
assert row["cost"] == sample_completion["cost"]
|
||||
assert row["start_time"] == sample_completion["start_time"]
|
||||
assert row["source_name"] == "TestAgent"
|
||||
|
||||
|
||||
def test_log_function_use(db_connection):
    """Verify that a logged function call is persisted to ``function_calls``.

    The rows are fetched before any field is checked and must be non-empty,
    so the test fails when nothing was written instead of passing vacuously
    through an empty loop.
    """
    cur = db_connection.cursor()

    source = autogen.AssistantAgent(name="TestAgent", code_execution_config=False)
    func: Callable[[str, int], Any] = dummy_function
    args = {"foo": "bar"}
    returns = True

    autogen.runtime_logging.log_function_use(agent=source, function=func, args=args, returns=returns)

    query = """
        SELECT source_id, source_name, function_name, args, returns, timestamp
        FROM function_calls
    """

    rows = cur.execute(query).fetchall()
    # Guard against a silent pass: without at least one row the per-row
    # assertions below would never run.
    assert rows, "log_function_use did not write any row to function_calls"

    for row in rows:
        assert row["source_name"] == "TestAgent"
        # Arguments and return value are compared against their JSON-encoded form.
        assert row["args"] == json.dumps(args)
        assert row["returns"] == json.dumps(returns)
|
||||
|
||||
|
||||
def test_log_new_agent(db_connection):
|
||||
|
||||
@ -100,6 +100,9 @@ Links to notebook examples:
|
||||
- Automatically Build Multi-agent System with AgentBuilder - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/autobuild_basic.ipynb)
|
||||
- Automatically Build Multi-agent System from Agent Library - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/autobuild_agent_library.ipynb)
|
||||
|
||||
### Observability
|
||||
- Track LLM calls, tool usage, actions and errors using AgentOps - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_agentops.ipynb)
|
||||
|
||||
## Enhanced Inferences
|
||||
|
||||
### Utilities
|
||||
|
||||
83
website/docs/ecosystem/agentops.md
Normal file
83
website/docs/ecosystem/agentops.md
Normal file
@ -0,0 +1,83 @@
|
||||
# AgentOps 🖇️
|
||||
|
||||

|
||||
|
||||
[AgentOps](https://agentops.ai/?=autogen) provides session replays, metrics, and monitoring for agents.
|
||||
|
||||
At a high level, AgentOps gives you the ability to monitor LLM calls, costs, latency, agent failures, multi-agent interactions, tool usage, session-wide statistics, and more. For more info, check out the [AgentOps Repo](https://github.com/AgentOps-AI/agentops).
|
||||
|
||||
<details open>
|
||||
<summary>Agent Dashboard</summary>
|
||||
<a href="https://app.agentops.ai?ref=gh">
|
||||
<img src="https://github.com/AgentOps-AI/agentops/assets/14807319/158e082a-9a7d-49b7-9b41-51a49a1f7d3d" style="width: 90%;" alt="Agent Dashboard"/>
|
||||
</a>
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Session Analytics</summary>
|
||||
<a href="https://app.agentops.ai?ref=gh">
|
||||
<img src="https://github.com/AgentOps-AI/agentops/assets/14807319/d7228019-1488-40d3-852f-a61e998658ad" style="width: 90%;" alt="Session Analytics"/>
|
||||
</a>
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Session Replays</summary>
|
||||
<a href="https://app.agentops.ai?ref=gh">
|
||||
<img src="https://github.com/AgentOps-AI/agentops/assets/14807319/561d59f3-c441-4066-914b-f6cfe32a598c" style="width: 90%;" alt="Session Replays"/>
|
||||
</a>
|
||||
</details>
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
AgentOps works seamlessly with applications built using Autogen.
|
||||
|
||||
1. **Install AgentOps**
|
||||
```bash
|
||||
pip install agentops
|
||||
```
|
||||
|
||||
2. **Create an API Key:**
|
||||
Create a user API key here: [Create API Key](https://app.agentops.ai/account)
|
||||
|
||||
3. **Configure Your Environment:**
|
||||
Add your API key to your environment variables
|
||||
|
||||
```
|
||||
AGENTOPS_API_KEY=<YOUR_AGENTOPS_API_KEY>
|
||||
```
|
||||
|
||||
4. **Initialize AgentOps**
|
||||
|
||||
To start tracking all available data on Autogen runs, add two lines of code before any of your Autogen code runs.
|
||||
|
||||
```python
|
||||
import agentops
|
||||
agentops.init() # Or: agentops.init(api_key="your-api-key-here")
|
||||
```
|
||||
|
||||
Once AgentOps is initialized, your Autogen agent runs are tracked automatically.
|
||||
|
||||
## Features
|
||||
|
||||
- **LLM Costs**: Track spend with foundation model providers
|
||||
- **Replay Analytics**: Watch step-by-step agent execution graphs
|
||||
- **Recursive Thought Detection**: Identify when agents fall into infinite loops
|
||||
- **Custom Reporting:** Create custom analytics on agent performance
|
||||
- **Analytics Dashboard:** Monitor high level statistics about agents in development and production
|
||||
- **Public Model Testing**: Test your agents against benchmarks and leaderboards
|
||||
- **Custom Tests:** Run your agents against domain specific tests
|
||||
- **Time Travel Debugging**: Save snapshots of session states to rewind and replay agent runs from chosen checkpoints.
|
||||
- **Compliance and Security**: Create audit logs and detect potential threats such as profanity and PII leaks
|
||||
- **Prompt Injection Detection**: Identify potential code injection and secret leaks
|
||||
|
||||
## Autogen + AgentOps examples
|
||||
* [AgentChat with AgentOps Notebook](/docs/notebooks/agentchat_agentops)
|
||||
* [More AgentOps Examples](https://docs.agentops.ai/v1/quickstart)
|
||||
|
||||
## Extra links
|
||||
|
||||
- [🐦 Twitter](https://twitter.com/agentopsai/)
|
||||
- [📢 Discord](https://discord.gg/JHPt4C7r)
|
||||
- [🖇️ AgentOps Dashboard](https://app.agentops.ai/ref?=autogen)
|
||||
- [📙 Documentation](https://docs.agentops.ai/introduction)
|
||||
42
website/docs/topics/llm-observability.md
Normal file
42
website/docs/topics/llm-observability.md
Normal file
@ -0,0 +1,42 @@
|
||||
# LLM Observability
|
||||
|
||||
AutoGen supports advanced LLM observability and monitoring through built-in logging and partner providers.
|
||||
|
||||
## What is LLM Observability
|
||||
AI agent observability is the ability to monitor, measure, and understand the internal states and behaviors of AI agent systems.
|
||||
Observability is crucial for ensuring transparency, reliability, and accountability in your agent systems.
|
||||
|
||||
|
||||
## Development
|
||||
|
||||
### Agent Development in Terminal is Limited
|
||||
- Losing track of what your agents did in between executions
|
||||
- Parsing through terminal output searching for LLM completions
|
||||
- Printing “tool called”
|
||||
|
||||
### Agent Development Dashboards Enable More
|
||||
- Visual dashboard so you can see what your agents did in human-readable format
|
||||
- LLM calls are magically recorded - prompt, completion, timestamps for each - with one line of code
|
||||
- Agents and their events (including tool calls) are recorded with one more line of code
|
||||
- Errors are automatically associated with their causal events
|
||||
- Record any other events to your session with two more lines of code
|
||||
- Plenty of other useful data, such as the SDK version, if you’re developing with a supported agent framework
|
||||
|
||||
## Compliance
|
||||
|
||||
Observability and monitoring are critical for ensuring that AI agent systems adhere to laws and regulations in industries like finance and healthcare, preventing violations such as data breaches and privacy issues.
|
||||
|
||||
- Insights into AI decision-making, allowing organizations to explain outcomes and build trust with stakeholders.
|
||||
- Helps detect anomalies and unintended behaviors early, mitigating operational, financial, and reputational risks.
|
||||
- Ensures compliance with data privacy regulations, preventing unauthorized access and misuse of sensitive information.
|
||||
- Quick identification and response to compliance violations, supporting incident analysis and prevention.
|
||||
|
||||
## Available Observability Integrations
|
||||
|
||||
### Logging
|
||||
- Autogen SQLite and File Logger - [Tutorial](/docs/notebooks/agentchat_logging)
|
||||
|
||||
### Full-Service Partners
|
||||
Autogen is currently partnered with [AgentOps](https://agentops.ai) for seamless observability integration.
|
||||
|
||||
[Learn how to install AgentOps](/docs/notebooks/agentchat_agentops)
|
||||
@ -279,6 +279,7 @@
|
||||
" def __deepcopy__(self, memo):\n",
|
||||
" return self\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"config_list = [\n",
|
||||
" {\n",
|
||||
" \"model\": \"my-gpt-4-deployment\",\n",
|
||||
|
||||
@ -39,8 +39,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-25T01:49:48.858694Z",
|
||||
"start_time": "2024-04-25T01:49:48.854420Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Annotated, Literal\n",
|
||||
@ -91,8 +96,24 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-25T01:49:48.946697Z",
|
||||
"start_time": "2024-04-25T01:49:48.857869Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<function __main__.calculator(a: int, b: int, operator: Annotated[Literal['+', '-', '*', '/'], 'operator']) -> int>"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
@ -160,8 +181,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-25T01:49:48.953345Z",
|
||||
"start_time": "2024-04-25T01:49:48.947026Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogen import register_function\n",
|
||||
@ -189,124 +215,128 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-25T01:49:57.947530Z",
|
||||
"start_time": "2024-04-25T01:49:48.953943Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"What is (44232 + 13312 / (232 - 32)) * 5?\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
|
||||
"\u001b[33mAssistant\u001b[0m (to User):\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_bACquf0OreI0VHh7rWiP6ZE7): calculator *****\u001b[0m\n",
|
||||
"\u001B[32m***** Suggested tool call (call_4rElPoLggOYJmkUutbGaSTX1): calculator *****\u001B[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\n",
|
||||
" \"a\": 13312,\n",
|
||||
" \"b\": 232 - 32,\n",
|
||||
" \"operator\": \"/\"\n",
|
||||
" \"a\": 232,\n",
|
||||
" \"b\": 32,\n",
|
||||
" \"operator\": \"-\"\n",
|
||||
"}\n",
|
||||
"\u001b[32m***************************************************************************\u001b[0m\n",
|
||||
"\u001B[32m***************************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001B[0m\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_bACquf0OreI0VHh7rWiP6ZE7\" *****\u001b[0m\n",
|
||||
"Error: Expecting ',' delimiter: line 1 column 26 (char 25)\n",
|
||||
" You argument should follow json format.\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\u001B[32m***** Response from calling tool (call_4rElPoLggOYJmkUutbGaSTX1) *****\u001B[0m\n",
|
||||
"200\n",
|
||||
"\u001B[32m**********************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
|
||||
"\u001b[33mAssistant\u001b[0m (to User):\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_2c0H5gzX9SWsJ05x7nEOVbav): calculator *****\u001b[0m\n",
|
||||
"\u001B[32m***** Suggested tool call (call_SGtr8tK9A4iOCJGdCqkKR2Ov): calculator *****\u001B[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\n",
|
||||
" \"a\": 13312,\n",
|
||||
" \"b\": 200,\n",
|
||||
" \"operator\": \"/\"\n",
|
||||
"}\n",
|
||||
"\u001b[32m***************************************************************************\u001b[0m\n",
|
||||
"\u001B[32m***************************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001b[0m\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001B[0m\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_2c0H5gzX9SWsJ05x7nEOVbav\" *****\u001b[0m\n",
|
||||
"\u001B[32m***** Response from calling tool (call_SGtr8tK9A4iOCJGdCqkKR2Ov) *****\u001B[0m\n",
|
||||
"66\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\u001B[32m**********************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
|
||||
"\u001b[33mAssistant\u001b[0m (to User):\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_ioceLhuKMpfU131E7TSQ8wCD): calculator *****\u001b[0m\n",
|
||||
"\u001B[32m***** Suggested tool call (call_YsR95CM1Ice2GZ7ZoStYXI6M): calculator *****\u001B[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\n",
|
||||
" \"a\": 44232,\n",
|
||||
" \"b\": 66,\n",
|
||||
" \"operator\": \"+\"\n",
|
||||
"}\n",
|
||||
"\u001b[32m***************************************************************************\u001b[0m\n",
|
||||
"\u001B[32m***************************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001b[0m\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001B[0m\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_ioceLhuKMpfU131E7TSQ8wCD\" *****\u001b[0m\n",
|
||||
"\u001B[32m***** Response from calling tool (call_YsR95CM1Ice2GZ7ZoStYXI6M) *****\u001B[0m\n",
|
||||
"44298\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\u001B[32m**********************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
|
||||
"\u001b[33mAssistant\u001b[0m (to User):\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_0rhx9vrbigcbqLssKLh4sS7j): calculator *****\u001b[0m\n",
|
||||
"\u001B[32m***** Suggested tool call (call_oqZn4rTjyvXYcmjAXkvVaJm1): calculator *****\u001B[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\n",
|
||||
" \"a\": 44298,\n",
|
||||
" \"b\": 5,\n",
|
||||
" \"operator\": \"*\"\n",
|
||||
"}\n",
|
||||
"\u001b[32m***************************************************************************\u001b[0m\n",
|
||||
"\u001B[32m***************************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001b[0m\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001B[0m\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_0rhx9vrbigcbqLssKLh4sS7j\" *****\u001b[0m\n",
|
||||
"\u001B[32m***** Response from calling tool (call_oqZn4rTjyvXYcmjAXkvVaJm1) *****\u001B[0m\n",
|
||||
"221490\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\u001B[32m**********************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
|
||||
"\u001b[33mAssistant\u001b[0m (to User):\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"The result of the calculation (44232 + 13312 / (232 - 32)) * 5 is 221490. \n",
|
||||
"\n",
|
||||
"TERMINATE\n",
|
||||
"The result of the calculation is 221490. TERMINATE\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n"
|
||||
]
|
||||
@ -325,8 +355,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-25T01:49:57.956063Z",
|
||||
"start_time": "2024-04-25T01:49:57.948882Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@ -334,7 +369,7 @@
|
||||
"221490"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -369,8 +404,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-25T01:49:57.956696Z",
|
||||
"start_time": "2024-04-25T01:49:57.952767Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@ -387,7 +427,7 @@
|
||||
" 'required': ['a', 'b', 'operator']}}}]"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -417,8 +457,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-25T01:49:57.965069Z",
|
||||
"start_time": "2024-04-25T01:49:57.958274Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pydantic import BaseModel, Field\n",
|
||||
@ -459,8 +504,24 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-25T01:49:57.990811Z",
|
||||
"start_time": "2024-04-25T01:49:57.962315Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<function __main__.calculator(input: typing.Annotated[__main__.CalculatorInput, 'Input to the calculator.']) -> int>"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"assistant.register_for_llm(name=\"calculator\", description=\"A calculator tool that accepts nested expression as input\")(\n",
|
||||
" calculator\n",
|
||||
@ -477,8 +538,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-25T01:49:57.991342Z",
|
||||
"start_time": "2024-04-25T01:49:57.972554Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@ -504,7 +570,7 @@
|
||||
" 'required': ['input']}}}]"
|
||||
]
|
||||
},
|
||||
"execution_count": 28,
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -522,159 +588,192 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-25T01:50:17.808416Z",
|
||||
"start_time": "2024-04-25T01:49:57.975143Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"What is (1423 - 123) / 3 + (32 + 23) * 5?\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
|
||||
"\u001b[33mAssistant\u001b[0m (to User):\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_t9By3vewGRoSLWsvdTR7p8Zo): calculator *****\u001b[0m\n",
|
||||
"\u001B[32m***** Suggested tool call (call_Uu4diKtxlTfkwXuY6MmJEb4E): calculator *****\u001B[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\n",
|
||||
" \"input\": {\n",
|
||||
" \"a\": 1423,\n",
|
||||
" \"b\": 123,\n",
|
||||
" \"operator\": \"-\"\n",
|
||||
" }\n",
|
||||
" \"input\": {\n",
|
||||
" \"a\": (1423 - 123) / 3,\n",
|
||||
" \"b\": (32 + 23) * 5,\n",
|
||||
" \"operator\": \"+\"\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\u001b[32m***************************************************************************\u001b[0m\n",
|
||||
"\u001B[32m***************************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001b[0m\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_t9By3vewGRoSLWsvdTR7p8Zo\" *****\u001b[0m\n",
|
||||
"\u001B[32m***** Response from calling tool (call_Uu4diKtxlTfkwXuY6MmJEb4E) *****\u001B[0m\n",
|
||||
"Error: Expecting value: line 1 column 29 (char 28)\n",
|
||||
" You argument should follow json format.\n",
|
||||
"\u001B[32m**********************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"I apologize for the confusion, I seem to have made a mistake. Let me recalculate the expression properly.\n",
|
||||
"\n",
|
||||
"First, we need to do the calculations within the brackets. So, calculating (1423 - 123), (32 + 23), and then performing remaining operations.\n",
|
||||
"\u001B[32m***** Suggested tool call (call_mx3M3fNOwikFNoqSojDH1jIr): calculator *****\u001B[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\n",
|
||||
" \"input\": {\n",
|
||||
" \"a\": 1423,\n",
|
||||
" \"b\": 123,\n",
|
||||
" \"operator\": \"-\"\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\u001B[32m***************************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001B[0m\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001B[32m***** Response from calling tool (call_mx3M3fNOwikFNoqSojDH1jIr) *****\u001B[0m\n",
|
||||
"1300\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\u001B[32m**********************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
|
||||
"\u001b[33mAssistant\u001b[0m (to User):\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_rhecyhVCo0Y8HPL193xOUPE6): calculator *****\u001b[0m\n",
|
||||
"\u001B[32m***** Suggested tool call (call_hBAL2sYi6Y5ZtTHCNPCmxdN3): calculator *****\u001B[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\n",
|
||||
" \"input\": {\n",
|
||||
" \"a\": 1300,\n",
|
||||
" \"b\": 3,\n",
|
||||
" \"operator\": \"/\"\n",
|
||||
" }\n",
|
||||
" \"input\": {\n",
|
||||
" \"a\": 32,\n",
|
||||
" \"b\": 23,\n",
|
||||
" \"operator\": \"+\"\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\u001b[32m***************************************************************************\u001b[0m\n",
|
||||
"\u001B[32m***************************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001b[0m\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001B[0m\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_rhecyhVCo0Y8HPL193xOUPE6\" *****\u001b[0m\n",
|
||||
"433\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
|
||||
"\u001b[33mAssistant\u001b[0m (to User):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_zDpq9J5MYAsL7uS8cobOwa7S): calculator *****\u001b[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\n",
|
||||
" \"input\": {\n",
|
||||
" \"a\": 32,\n",
|
||||
" \"b\": 23,\n",
|
||||
" \"operator\": \"+\"\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\u001b[32m***************************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001b[0m\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_zDpq9J5MYAsL7uS8cobOwa7S\" *****\u001b[0m\n",
|
||||
"\u001B[32m***** Response from calling tool (call_hBAL2sYi6Y5ZtTHCNPCmxdN3) *****\u001B[0m\n",
|
||||
"55\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\u001B[32m**********************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
|
||||
"\u001b[33mAssistant\u001b[0m (to User):\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_mjDuVMojOIdaxmvDUIF4QtVi): calculator *****\u001b[0m\n",
|
||||
"\u001B[32m***** Suggested tool call (call_wO3AP7EDeJvsVLCpvv5LohUa): calculator *****\u001B[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\n",
|
||||
" \"input\": {\n",
|
||||
" \"a\": 55,\n",
|
||||
" \"b\": 5,\n",
|
||||
" \"operator\": \"*\"\n",
|
||||
" }\n",
|
||||
" \"input\": {\n",
|
||||
" \"a\": 1300,\n",
|
||||
" \"b\": 3,\n",
|
||||
" \"operator\": \"/\"\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\u001b[32m***************************************************************************\u001b[0m\n",
|
||||
"\u001B[32m***************************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001b[0m\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001B[0m\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_mjDuVMojOIdaxmvDUIF4QtVi\" *****\u001b[0m\n",
|
||||
"\u001B[32m***** Response from calling tool (call_wO3AP7EDeJvsVLCpvv5LohUa) *****\u001B[0m\n",
|
||||
"433\n",
|
||||
"\u001B[32m**********************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"\u001B[32m***** Suggested tool call (call_kQ2hDhqem8BHNlaHaE9ezvvQ): calculator *****\u001B[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\n",
|
||||
" \"input\": {\n",
|
||||
" \"a\": 55,\n",
|
||||
" \"b\": 5,\n",
|
||||
" \"operator\": \"*\"\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\u001B[32m***************************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001B[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001B[0m\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001B[32m***** Response from calling tool (call_kQ2hDhqem8BHNlaHaE9ezvvQ) *****\u001B[0m\n",
|
||||
"275\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\u001B[32m**********************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
|
||||
"\u001b[33mAssistant\u001b[0m (to User):\n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_hpirkAGKOewZstsDOxL2sYNW): calculator *****\u001b[0m\n",
|
||||
"\u001B[32m***** Suggested tool call (call_1FLDUdvAZmjlSD7g5GFFJOpO): calculator *****\u001B[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\n",
|
||||
" \"input\": {\n",
|
||||
" \"a\": 433,\n",
|
||||
" \"b\": 275,\n",
|
||||
" \"operator\": \"+\"\n",
|
||||
" }\n",
|
||||
" \"input\": {\n",
|
||||
" \"a\": 433,\n",
|
||||
" \"b\": 275,\n",
|
||||
" \"operator\": \"+\"\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\u001b[32m***************************************************************************\u001b[0m\n",
|
||||
"\u001B[32m***************************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001b[0m\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION calculator...\u001B[0m\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[33mUser\u001b[0m (to Assistant):\n",
|
||||
"\u001B[33mUser\u001B[0m (to Assistant):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_hpirkAGKOewZstsDOxL2sYNW\" *****\u001b[0m\n",
|
||||
"\u001B[32m***** Response from calling tool (call_1FLDUdvAZmjlSD7g5GFFJOpO) *****\u001B[0m\n",
|
||||
"708\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\u001B[32m**********************************************************************\u001B[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
|
||||
"\u001b[33mAssistant\u001b[0m (to User):\n",
|
||||
"\n",
|
||||
"The result of the calculation is 708. \n",
|
||||
"\u001B[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001B[0m\n",
|
||||
"\u001B[33mAssistant\u001B[0m (to User):\n",
|
||||
"\n",
|
||||
"The calculation result of the expression (1423 - 123) / 3 + (32 + 23) * 5 is 708. Let's proceed to the next task.\n",
|
||||
"TERMINATE\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n"
|
||||
@ -694,8 +793,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-04-25T01:50:17.818095Z",
|
||||
"start_time": "2024-04-25T01:50:17.808502Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@ -703,7 +807,7 @@
|
||||
"708"
|
||||
]
|
||||
},
|
||||
"execution_count": 31,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user