Mirror of https://github.com/microsoft/autogen.git, synced 2025-11-02 10:50:03 +00:00
add tool call to chat completion agent (#35)
* add tool call to chat completion agent
* refactor function executor; tool executor in chat completion agent
* update example
* update orchestrator chat demo
* handle function execution result message type
* format
* temp fix for examples.
* fix
* update chat completion agent
This commit is contained in:
parent 2dc7af87ef
commit 04c30596ed

examples/orchestrator.py (new file, 155 lines)
@@ -0,0 +1,155 @@
import argparse
import asyncio
import json
import logging
import os
from typing import Annotated, Callable

import openai
from agnext.agent_components.function_executor._impl.in_process_function_executor import (
    InProcessFunctionExecutor,
)
from agnext.agent_components.model_client import OpenAI
from agnext.agent_components.types import SystemMessage
from agnext.application_components import (
    SingleThreadedAgentRuntime,
)
from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.patterns.orchestrator_chat import OrchestratorChat
from agnext.chat.types import TextMessage
from agnext.core import Agent, AgentRuntime
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
from tavily import TavilyClient
from typing_extensions import Any, override

logging.basicConfig(level=logging.WARNING)
logging.getLogger("agnext").setLevel(logging.DEBUG)


class LoggingHandler(DefaultInterventionHandler):  # type: ignore
    send_color = "\033[31m"
    response_color = "\033[34m"
    reset_color = "\033[0m"

    @override
    async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]:  # type: ignore
        if sender is None:
            print(f"{self.send_color}Sending message to {recipient.name}:{self.reset_color} {message}")
        else:
            print(
                f"{self.send_color}Sending message from {sender.name} to {recipient.name}:{self.reset_color} {message}"
            )
        return message

    @override
    async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]:  # type: ignore
        if recipient is None:
            print(f"{self.response_color}Received response from {sender.name}:{self.reset_color} {message}")
        else:
            print(
                f"{self.response_color}Received response from {sender.name} to {recipient.name}:{self.reset_color} {message}"
            )
        return message


def software_development(runtime: AgentRuntime) -> OrchestratorChat:  # type: ignore
    developer = ChatCompletionAgent(
        name="Developer",
        description="A developer that writes code.",
        runtime=runtime,
        system_messages=[SystemMessage("You are a Python developer.")],
        model_client=OpenAI(model="gpt-4-turbo"),
    )

    tester_oai_assistant = openai.beta.assistants.create(
        model="gpt-4-turbo",
        description="A software tester that runs test cases and reports results.",
        instructions="You are a software tester that runs test cases and reports results.",
    )
    tester_oai_thread = openai.beta.threads.create()
    tester = OpenAIAssistantAgent(
        name="Tester",
        description="A software tester that runs test cases and reports results.",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=tester_oai_assistant.id,
        thread_id=tester_oai_thread.id,
    )

    def search(query: Annotated[str, "The search query."]) -> Annotated[str, "The search results."]:
        """Search the web."""
        client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
        result = client.search(query)  # type: ignore
        if result:
            return json.dumps(result, indent=2, ensure_ascii=False)  # type: ignore
        return "No results found."

    function_executor = InProcessFunctionExecutor(functions=[search])

    product_manager = ChatCompletionAgent(
        name="ProductManager",
        description="A product manager that performs research and comes up with specs.",
        runtime=runtime,
        system_messages=[
            SystemMessage("You are a product manager good at translating customer needs into software specifications."),
            SystemMessage("You can use the search tool to find information on the web."),
        ],
        model_client=OpenAI(model="gpt-4-turbo"),
        function_executor=function_executor,
    )

    planner = ChatCompletionAgent(
        name="Planner",
        description="A planner that organizes and schedules tasks.",
        runtime=runtime,
        system_messages=[SystemMessage("You are a planner of complex tasks.")],
        model_client=OpenAI(model="gpt-4-turbo"),
    )

    orchestrator = ChatCompletionAgent(
        name="Orchestrator",
        description="An orchestrator that coordinates the team.",
        runtime=runtime,
        system_messages=[
            SystemMessage("You are an orchestrator that coordinates the team to complete a complex task.")
        ],
        model_client=OpenAI(model="gpt-4-turbo"),
    )

    return OrchestratorChat(
        "OrchestratorChat",
        "A software development team.",
        runtime,
        orchestrator=orchestrator,
        planner=planner,
        specialists=[developer, product_manager, tester],
    )


async def run(message: str, user: str, scenario: Callable[[AgentRuntime], OrchestratorChat]) -> None:  # type: ignore
    runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())
    chat = scenario(runtime)
    response = runtime.send_message(TextMessage(content=message, source=user), chat)
    while not response.done():
        await runtime.process_next()
    print((await response).content)  # type: ignore


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a orchestrator demo.")
    choices = {"software_development": software_development}
    parser.add_argument(
        "--scenario",
        choices=list(choices.keys()),
        help="The scenario to demo.",
        default="software_development",
    )
    parser.add_argument(
        "--user",
        default="Customer",
        help="The user to send the message. Default is 'Customer'.",
    )
    parser.add_argument("--message", help="The message to send.", required=True)
    args = parser.parse_args()
    asyncio.run(run(args.message, args.user, choices[args.scenario]))
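The new example above is driven by argparse. A minimal sketch of driving it programmatically instead of through the CLI; the task string is illustrative, and the import assumes you run from the repository root with OPENAI_API_KEY and TAVILY_API_KEY set:

# Hypothetical driver, equivalent to:
#   python examples/orchestrator.py --scenario software_development --message "..."
import asyncio

from examples.orchestrator import run, software_development

asyncio.run(run("Write and unit-test a prime sieve.", "Customer", software_development))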
@@ -1,268 +0,0 @@
import argparse
import asyncio
import logging

import openai
from agnext.agent_components.types import SystemMessage
from agnext.application_components import (
    SingleThreadedAgentRuntime,
)
from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
from agnext.chat.patterns.orchestrator_chat import OrchestratorChat
from agnext.chat.types import TextMessage
from agnext.core._agent import Agent
from agnext.agent_components.model_client import OpenAI
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
from typing_extensions import Any, override

logging.basicConfig(level=logging.WARNING)
logging.getLogger("agnext").setLevel(logging.DEBUG)


class ConcatOutput(GroupChatOutput):
    def __init__(self) -> None:
        self._output = ""

    def on_message_received(self, message: Any) -> None:
        match message:
            case TextMessage(content=content):
                self._output += content
            case _:
                ...

    def get_output(self) -> Any:
        return self._output

    def reset(self) -> None:
        self._output = ""


class LoggingHandler(DefaultInterventionHandler):
    send_color = "\033[31m"
    response_color = "\033[34m"
    reset_color = "\033[0m"

    @override
    async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]:
        if sender is None:
            print(f"{self.send_color}Sending message to {recipient.name}:{self.reset_color} {message}")
        else:
            print(
                f"{self.send_color}Sending message from {sender.name} to {recipient.name}:{self.reset_color} {message}"
            )
        return message

    @override
    async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]:
        if recipient is None:
            print(f"{self.response_color}Received response from {sender.name}:{self.reset_color} {message}")
        else:
            print(
                f"{self.response_color}Received response from {sender.name} to {recipient.name}:{self.reset_color} {message}"
            )
        return message


async def group_chat(message: str) -> None:
    runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())

    joe_oai_assistant = openai.beta.assistants.create(
        model="gpt-3.5-turbo",
        name="Joe",
        instructions="You are a commedian named Joe. Make the audience laugh.",
    )
    joe_oai_thread = openai.beta.threads.create()
    joe = OpenAIAssistantAgent(
        name="Joe",
        description="Joe the commedian.",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=joe_oai_assistant.id,
        thread_id=joe_oai_thread.id,
    )

    cathy_oai_assistant = openai.beta.assistants.create(
        model="gpt-3.5-turbo",
        name="Cathy",
        instructions="You are a poet named Cathy. Write beautiful poems.",
    )
    cathy_oai_thread = openai.beta.threads.create()
    cathy = OpenAIAssistantAgent(
        name="Cathy",
        description="Cathy the poet.",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=cathy_oai_assistant.id,
        thread_id=cathy_oai_thread.id,
    )

    chat = GroupChat(
        "Host",
        "A round-robin chat room.",
        runtime,
        [joe, cathy],
        num_rounds=5,
        output=ConcatOutput(),
    )

    response = runtime.send_message(TextMessage(content=message, source="host"), chat)

    while not response.done():
        await runtime.process_next()

    await response


async def orchestrator_oai_assistant(message: str) -> None:
    runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())

    developer_oai_assistant = openai.beta.assistants.create(
        model="gpt-3.5-turbo",
        name="Developer",
        instructions="You are a Python developer.",
    )
    developer_oai_thread = openai.beta.threads.create()
    developer = OpenAIAssistantAgent(
        name="Developer",
        description="A developer that writes code.",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=developer_oai_assistant.id,
        thread_id=developer_oai_thread.id,
    )

    product_manager_oai_assistant = openai.beta.assistants.create(
        model="gpt-3.5-turbo",
        name="ProductManager",
        instructions="You are a product manager good at translating customer needs into software specifications.",
    )
    product_manager_oai_thread = openai.beta.threads.create()
    product_manager = OpenAIAssistantAgent(
        name="ProductManager",
        description="A product manager that plans and comes up with specs.",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=product_manager_oai_assistant.id,
        thread_id=product_manager_oai_thread.id,
    )

    planner_oai_assistant = openai.beta.assistants.create(
        model="gpt-4-turbo",
        name="Planner",
        instructions="You are a planner of complex tasks.",
    )
    planner_oai_thread = openai.beta.threads.create()
    planner = OpenAIAssistantAgent(
        name="Planner",
        description="A planner that organizes and schedules tasks.",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=planner_oai_assistant.id,
        thread_id=planner_oai_thread.id,
    )

    orchestrator_oai_assistant = openai.beta.assistants.create(
        model="gpt-4-turbo",
        name="Orchestrator",
        instructions="You are an orchestrator that coordinates the team to complete a complex task.",
    )
    orchestrator_oai_thread = openai.beta.threads.create()
    orchestrator = OpenAIAssistantAgent(
        name="Orchestrator",
        description="An orchestrator that coordinates the team.",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=orchestrator_oai_assistant.id,
        thread_id=orchestrator_oai_thread.id,
    )

    chat = OrchestratorChat(
        "OrchestratorChat",
        "A software development team.",
        runtime,
        orchestrator=orchestrator,
        planner=planner,
        specialists=[developer, product_manager],
    )

    response = runtime.send_message(TextMessage(content=message, source="Customer"), chat)

    while not response.done():
        await runtime.process_next()

    print((await response).content)  # type: ignore


async def orchestrator_chat_completion(message: str) -> None:
    runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())

    developer = ChatCompletionAgent(
        name="Developer",
        description="A developer that writes code.",
        runtime=runtime,
        system_messages=[SystemMessage("You are a Python developer.")],
        model_client=OpenAI(model="gpt-3.5-turbo"),
    )

    product_manager = ChatCompletionAgent(
        name="ProductManager",
        description="A product manager that plans and comes up with specs.",
        runtime=runtime,
        system_messages=[
            SystemMessage("You are a product manager good at translating customer needs into software specifications.")
        ],
        model_client=OpenAI(model="gpt-3.5-turbo"),
    )

    planner = ChatCompletionAgent(
        name="Planner",
        description="A planner that organizes and schedules tasks.",
        runtime=runtime,
        system_messages=[SystemMessage("You are a planner of complex tasks.")],
        model_client=OpenAI(model="gpt-4-turbo"),
    )

    orchestrator = ChatCompletionAgent(
        name="Orchestrator",
        description="An orchestrator that coordinates the team.",
        runtime=runtime,
        system_messages=[
            SystemMessage("You are an orchestrator that coordinates the team to complete a complex task.")
        ],
        model_client=OpenAI(model="gpt-4-turbo"),
    )

    chat = OrchestratorChat(
        "OrchestratorChat",
        "A software development team.",
        runtime,
        orchestrator=orchestrator,
        planner=planner,
        specialists=[developer, product_manager],
    )

    response = runtime.send_message(TextMessage(content=message, source="Customer"), chat)

    while not response.done():
        await runtime.process_next()

    print((await response).content)  # type: ignore


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a pattern demo.")
    choices = {
        "group_chat": group_chat,
        "orchestrator_oai_assistant": orchestrator_oai_assistant,
        "orchestrator_chat_completion": orchestrator_chat_completion,
    }
    parser.add_argument(
        "--pattern",
        choices=list(choices.keys()),
        help="The pattern to demo.",
    )
    parser.add_argument("--message", help="The message to send.")
    args = parser.parse_args()
    asyncio.run(choices[args.pattern](args.message))
@@ -72,7 +72,7 @@ disallow_untyped_decorators = true
disallow_any_unimported = true

[tool.pyright]
include = ["src", "examples", "tests"]
include = ["src", "tests"]
typeCheckingMode = "strict"
reportUnnecessaryIsInstance = false
reportMissingTypeStubs = false
@@ -1,4 +1,9 @@
from ._base import FunctionExecutor, FunctionInfo, into_function_definition
from ._base import Function, FunctionExecutor, into_function_signature
from ._impl.in_process_function_executor import InProcessFunctionExecutor

__all__ = ["FunctionExecutor", "FunctionInfo", "into_function_definition", "InProcessFunctionExecutor"]
__all__ = [
    "FunctionExecutor",
    "Function",
    "into_function_signature",
    "InProcessFunctionExecutor",
]
@@ -4,7 +4,13 @@ from typing import Any, Callable, Dict, Protocol, TypedDict, Union, runtime_checkable
from typing_extensions import NotRequired, Required

from ..function_utils import get_function_schema
from ..types import FunctionDefinition
from ..types import FunctionSignature


class Function(TypedDict):
    func: Required[Callable[..., Any]]
    name: NotRequired[str]
    description: NotRequired[str]


@runtime_checkable
@@ -12,25 +18,23 @@ class FunctionExecutor(Protocol):
    async def execute_function(self, function_name: str, arguments: Dict[str, Any]) -> str: ...

    @property
    def functions(self) -> Sequence[str]: ...
    def functions(self) -> Sequence[Function]: ...

    @property
    def function_signatures(self) -> Sequence[FunctionSignature]:
        return [into_function_signature(func) for func in self.functions]


class FunctionInfo(TypedDict):
    func: Required[Callable[..., Any]]
    name: NotRequired[str]
    description: NotRequired[str]


def into_function_definition(
    func_info: Union[FunctionInfo, FunctionDefinition, Callable[..., Any]],
) -> FunctionDefinition:
    if isinstance(func_info, FunctionDefinition):
        return func_info
    elif isinstance(func_info, dict):
        name = func_info.get("name", func_info["func"].__name__)
        description = func_info.get("description", "")
        parameters = get_function_schema(func_info["func"], description="", name="")["function"]["parameters"]
        return FunctionDefinition(name=name, description=description, parameters=parameters)
def into_function_signature(
    func_data: Union[Function, FunctionSignature, Callable[..., Any]],
) -> FunctionSignature:
    if isinstance(func_data, FunctionSignature):
        return func_data
    elif isinstance(func_data, dict):
        name = func_data.get("name", func_data["func"].__name__)
        description = func_data.get("description", "")
        parameters = get_function_schema(func_data["func"], description="", name="")["function"]["parameters"]
        return FunctionSignature(name=name, description=description, parameters=parameters)
    else:
        parameters = get_function_schema(func_info, description="", name="")["function"]["parameters"]
        return FunctionDefinition(name=func_info.__name__, description="", parameters=parameters)
        parameters = get_function_schema(func_data, description="", name="")["function"]["parameters"]
        return FunctionSignature(name=func_data.__name__, description="", parameters=parameters)
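The renamed helper still accepts three shapes of input: a bare callable, a Function dict, or an already-built FunctionSignature. A quick sketch of each; the greet function here is illustrative, not from the commit:

# Sketch: the three input shapes accepted by into_function_signature.
from agnext.agent_components.function_executor import Function, into_function_signature
from agnext.agent_components.types import FunctionSignature

def greet(name: str) -> str:
    """Say hello."""
    return f"Hello, {name}!"

sig = into_function_signature(greet)  # schema derived from the annotations
sig = into_function_signature(Function(func=greet, description="Say hello."))
sig = into_function_signature(FunctionSignature(name="greet", parameters={}, description=""))  # passed through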
@@ -1,40 +1,53 @@
import asyncio
import functools
from collections.abc import Sequence
from typing import Any, Callable, Union
from typing import Any, Callable, Dict, Union

from .._base import FunctionExecutor, FunctionInfo
from .._base import Function, FunctionExecutor


class InProcessFunctionExecutor(FunctionExecutor):
    def __init__(
        self,
        functions: Sequence[Union[Callable[..., Any], FunctionInfo]] = [],
        functions: Sequence[Union[Callable[..., Any], Function]] = [],
    ) -> None:
        def _name(func: Union[Callable[..., Any], FunctionInfo]) -> str:
        def _name(func: Union[Callable[..., Any], Function]) -> str:
            if isinstance(func, dict):
                return func.get("name", func["func"].__name__)
            else:
                return func.__name__

        def _func(func: Union[Callable[..., Any], FunctionInfo]) -> Any:
        def _func(func: Union[Callable[..., Any], Function]) -> Any:
            if isinstance(func, dict):
                return func.get("func")
            else:
                return func

        self._functions = dict([(_name(x), _func(x)) for x in functions])
        def _description(func: Union[Callable[..., Any], Function]) -> str:
            if isinstance(func, dict):
                return func.get("description", "")
            else:
                return ""

        self._functions: Dict[str, Function] = dict()
        for func in functions:
            name = _name(func)
            self._functions[name] = Function(
                func=_func(func),
                name=name,
                description=_description(func),
            )

    async def execute_function(self, function_name: str, arguments: dict[str, Any]) -> str:
        if function_name in self._functions:
            function = self._functions[function_name]
            function = self._functions[function_name]["func"]
            if asyncio.iscoroutinefunction(function):
                return str(function(**arguments))
            else:
                return await asyncio.get_event_loop().run_in_executor(None, functools.partial(function, **arguments))

        raise ValueError(f"Function {function_name} not found")
        raise ValueError(f"Function {function_name} not found.")

    @property
    def functions(self) -> Sequence[str]:
        return list(self._functions.keys())
    def functions(self) -> Sequence[Function]:
        return list(self._functions.values())
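The executor now keeps the full Function record (func, name, description) instead of a bare name-to-callable map, which is what lets the protocol's function_signatures property work. A usage sketch; the add function is illustrative:

# Sketch: registering a plain callable and invoking it by name.
import asyncio

from agnext.agent_components.function_executor import InProcessFunctionExecutor

def add(a: int, b: int) -> int:
    return a + b

executor = InProcessFunctionExecutor(functions=[add])
print(asyncio.run(executor.execute_function("add", {"a": 1, "b": 2})))  # 3; sync callables run on the default thread-pool executor
print([f["name"] for f in executor.functions])  # ["add"]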
@@ -1,9 +1,21 @@
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/function_utils.py
# File based from: https://github.com/microsoft/autogen/blob/47f905267245e143562abfb41fcba503a9e1d56d/autogen/function_utils.py
# Credit to original authors

import inspect
from logging import getLogger
from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, Tuple, Type, TypeVar, Union
from typing import (
    Any,
    Callable,
    Dict,
    ForwardRef,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from pydantic import BaseModel, Field
from typing_extensions import Annotated, Literal
@@ -74,7 +86,9 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
    return get_typed_annotation(annotation, globalns)


def get_param_annotations(typed_signature: inspect.Signature) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]:
def get_param_annotations(
    typed_signature: inspect.Signature,
) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]:
    """Get the type annotations of the parameters of a function

    Args:
@@ -11,7 +11,7 @@ from typing_extensions import (
    Union,
)

from ..types import CreateResult, FunctionDefinition, LLMMessage, RequestUsage
from ..types import CreateResult, FunctionSignature, LLMMessage, RequestUsage


class ModelCapabilities(TypedDict, total=False):
@@ -26,7 +26,7 @@ class ModelClient(Protocol):
    async def create(
        self,
        messages: Sequence[LLMMessage],
        functions: Sequence[FunctionDefinition] = [],
        functions: Sequence[FunctionSignature] = [],
        # None means do not override the default
        # A value means to override the client default - often specified in the constructor
        json_output: Optional[bool] = None,
@@ -36,7 +36,7 @@ class ModelClient(Protocol):
    def create_stream(
        self,
        messages: Sequence[LLMMessage],
        functions: Sequence[FunctionDefinition] = [],
        functions: Sequence[FunctionSignature] = [],
        # None means do not override the default
        # A value means to override the client default - often specified in the constructor
        json_output: Optional[bool] = None,
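The protocol change is a rename at the type level; callers now hand FunctionSignature values straight to create. A hedged sketch of a direct call, assuming UserMessage takes content and source the way the conversion helpers later in this diff suggest:

# Sketch: passing FunctionSignature values to ModelClient.create.
from agnext.agent_components.model_client import OpenAI
from agnext.agent_components.types import FunctionSignature, UserMessage

async def ask_with_tools() -> None:
    client = OpenAI(model="gpt-4-turbo")
    search_sig = FunctionSignature(
        name="search",
        parameters={"type": "object", "properties": {"query": {"type": "string"}}},
        description="Search the web.",
    )
    result = await client.create(
        [UserMessage(content="Look up agnext.", source="User")],
        functions=[search_sig],
    )
    print(result.content)  # a str, or a list of FunctionCall if the model chose the tool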
@@ -40,8 +40,8 @@ from ..types import (
    AssistantMessage,
    CreateResult,
    FunctionCall,
    FunctionDefinition,
    FunctionExecutionResultMessage,
    FunctionSignature,
    LLMMessage,
    RequestUsage,
    SystemMessage,
@@ -250,7 +250,7 @@ class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False)


def convert_functions(
    functions: Sequence[FunctionDefinition],
    functions: Sequence[FunctionSignature],
) -> List[ChatCompletionToolParam]:
    result: List[ChatCompletionToolParam] = []
    for func in functions:
@@ -304,7 +304,7 @@ class BaseOpenAI(ModelClient):
    async def create(
        self,
        messages: Sequence[LLMMessage],
        functions: Sequence[FunctionDefinition] = [],
        functions: Sequence[FunctionSignature] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
    ) -> CreateResult:
@@ -353,7 +353,10 @@ class BaseOpenAI(ModelClient):

        if result.usage is not None:
            logger.info(
                LLMCallEvent(prompt_tokens=result.usage.prompt_tokens, completion_tokens=result.usage.completion_tokens)
                LLMCallEvent(
                    prompt_tokens=result.usage.prompt_tokens,
                    completion_tokens=result.usage.completion_tokens,
                )
            )

        usage = RequestUsage(
@@ -400,7 +403,7 @@ class BaseOpenAI(ModelClient):
    async def create_stream(
        self,
        messages: Sequence[LLMMessage],
        functions: Sequence[FunctionDefinition] = [],
        functions: Sequence[FunctionSignature] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
@@ -1,4 +1,4 @@
# File based from: https://github.com/microsoft/autogen/blob/main/autogen/_pydantic.py
# File based from: https://github.com/microsoft/autogen/blob/47f905267245e143562abfb41fcba503a9e1d56d/autogen/_pydantic.py
# Credit to original authors


@@ -14,7 +14,9 @@ PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")


def evaluate_forwardref(
    value: Any, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None
    value: Any,
    globalns: dict[str, Any] | None = None,
    localns: dict[str, Any] | None = None,
) -> Any:
    if PYDANTIC_V1:
        from pydantic.typing import evaluate_forwardref as evaluate_forwardref_internal
@@ -18,7 +18,7 @@ class FunctionCall:


@dataclass
class FunctionDefinition:
class FunctionSignature:
    name: str
    parameters: Dict[str, Any]
    description: str
@@ -1,10 +1,26 @@
from typing import Any, Callable, Dict, List, Mapping
import asyncio
import json
from typing import Any, Coroutine, Dict, List, Mapping, Tuple

from agnext.agent_components.function_executor import FunctionExecutor
from agnext.agent_components.model_client import ModelClient
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.agent_components.types import SystemMessage
from agnext.agent_components.types import (
    FunctionCall,
    FunctionSignature,
    SystemMessage,
)
from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.types import Message, Reset, RespondNow, ResponseFormat, TextMessage
from agnext.chat.types import (
    FunctionCallMessage,
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
    Message,
    Reset,
    RespondNow,
    ResponseFormat,
    TextMessage,
)
from agnext.chat.utils import convert_messages_to_llm_messages
from agnext.core import AgentRuntime, CancellationToken

@@ -17,13 +33,13 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
        runtime: AgentRuntime,
        system_messages: List[SystemMessage],
        model_client: ModelClient,
        tools: Dict[str, Callable[..., str]] | None = None,
        function_executor: FunctionExecutor | None = None,
    ) -> None:
        super().__init__(name, description, runtime)
        self._system_messages = system_messages
        self._client = model_client
        self._tools = tools or {}
        self._chat_messages: List[Message] = []
        self._function_executor = function_executor

    @message_handler(TextMessage)
    async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
@@ -36,20 +52,112 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
        self._chat_messages = []

    @message_handler(RespondNow)
    async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
        if message.response_format == ResponseFormat.json_object:
            response_format = {"type": "json_object"}
        else:
            response_format = {"type": "text"}
    async def on_respond_now(
        self, message: RespondNow, cancellation_token: CancellationToken
    ) -> TextMessage | FunctionCallMessage:
        # Get function signatures.
        function_signatures: List[FunctionSignature] = (
            [] if self._function_executor is None else list(self._function_executor.function_signatures)
        )

        # Get a response from the model.
        response = await self._client.create(
            self._system_messages + convert_messages_to_llm_messages(self._chat_messages, self.name),
            extra_create_args={"response_format": response_format},
            functions=function_signatures,
            json_output=message.response_format == ResponseFormat.json_object,
        )

        # If the agent has function executor, and the response is a list of
        # tool calls, iterate with itself until we get a response that is not a
        # list of tool calls.
        while (
            self._function_executor is not None
            and isinstance(response.content, list)
            and all(isinstance(x, FunctionCall) for x in response.content)
        ):
            # Send a function call message to itself.
            response = await self._send_message(
                message=FunctionCallMessage(content=response.content, source=self.name),
                recipient=self,
                cancellation_token=cancellation_token,
            )
            # Make an assistant message from the response.
            response = await self._client.create(
                self._system_messages + convert_messages_to_llm_messages(self._chat_messages, self.name),
                functions=function_signatures,
                json_output=message.response_format == ResponseFormat.json_object,
            )

        final_response: Message
        if isinstance(response.content, str):
            return TextMessage(content=response.content, source=self.name)
            # If the response is a string, return a text message.
            final_response = TextMessage(content=response.content, source=self.name)
        elif isinstance(response.content, list) and all(isinstance(x, FunctionCall) for x in response.content):
            # If the response is a list of function calls, return a function call message.
            final_response = FunctionCallMessage(content=response.content, source=self.name)
        else:
            raise ValueError(f"Unexpected response: {response.content}")

        # Add the response to the chat messages.
        self._chat_messages.append(final_response)

        # Return the response.
        return final_response

    @message_handler(FunctionCallMessage)
    async def on_tool_call_message(
        self, message: FunctionCallMessage, cancellation_token: CancellationToken
    ) -> FunctionExecutionResultMessage:
        if self._function_executor is None:
            raise ValueError("Function executor is not set.")

        # Add a tool call message.
        self._chat_messages.append(message)

        # Execute the tool calls.
        results: List[FunctionExecutionResult] = []
        execution_futures: List[Coroutine[Any, Any, Tuple[str, str]]] = []
        for function_call in message.content:
            # Parse the arguments.
            try:
                arguments = json.loads(function_call.arguments)
            except json.JSONDecodeError:
                results.append(
                    FunctionExecutionResult(
                        content=f"Error: Could not parse arguments for function {function_call.name}.",
                        call_id=function_call.id,
                    )
                )
                continue
            # Execute the function.
            future = self.execute_function(function_call.name, arguments, function_call.id)
            # Append the async result.
            execution_futures.append(future)
        if execution_futures:
            # Wait for all async results.
            execution_results = await asyncio.gather(*execution_futures)
            # Add the results.
            for execution_result, call_id in execution_results:
                results.append(FunctionExecutionResult(content=execution_result, call_id=call_id))

        # Create a tool call result message.
        tool_call_result_msg = FunctionExecutionResultMessage(content=results, source=self.name)

        # Add tool call result message.
        self._chat_messages.append(tool_call_result_msg)

        # Return the results.
        return tool_call_result_msg

    async def execute_function(self, name: str, args: Dict[str, Any], call_id: str) -> Tuple[str, str]:
        if self._function_executor is None:
            raise ValueError("Function executor is not set.")
        try:
            result = await self._function_executor.execute_function(name, args)
        except Exception as e:
            result = f"Error: {str(e)}"
        return (result, call_id)

    def save_state(self) -> Mapping[str, Any]:
        return {
            "description": self.description,
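This hunk is the heart of the change: on RespondNow, the agent collects signatures from its executor, and while the model keeps answering with FunctionCall lists it sends a FunctionCallMessage to itself, so on_tool_call_message runs the calls concurrently via asyncio.gather, appends a FunctionExecutionResultMessage to the history, and the model is queried again. A minimal wiring sketch reusing names from examples/orchestrator.py above; the question text is illustrative:

# Sketch: a single tool-calling agent outside the orchestrator pattern.
runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())
assistant = ChatCompletionAgent(
    name="Assistant",
    description="An assistant with a web search tool.",
    runtime=runtime,
    system_messages=[SystemMessage("You are a helpful assistant.")],
    model_client=OpenAI(model="gpt-4-turbo"),
    function_executor=InProcessFunctionExecutor(functions=[search]),
)
runtime.send_message(TextMessage(content="What is AutoGen?", source="User"), assistant)
response = runtime.send_message(RespondNow(), assistant)  # assuming the default text response format
# ...then drive runtime.process_next() until response.done(), as run() does above.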
@@ -6,6 +6,8 @@ from ...core import AgentRuntime, CancellationToken
from ..agents.base import BaseChatAgent
from ..types import Reset, RespondNow, ResponseFormat, TextMessage

__all__ = ["OrchestratorChat"]


class OrchestratorChat(BaseChatAgent, TypeRoutedAgent):
    def __init__(
@@ -255,15 +257,65 @@ Please output an answer in pure JSON format according to the following schema. T
}}
}}
""".strip()
        # Send a message to the orchestrator.
        self._send_message(TextMessage(content=step_prompt, source=sender), self._orchestrator)
        # Request a response.
        step_response = await self._send_message(
            RespondNow(response_format=ResponseFormat.json_object), self._orchestrator
        )
        # TODO: handle invalid JSON.
        # TODO: use typed dictionary.
        return json.loads(step_response.content)
        request = step_prompt
        while True:
            # Send a message to the orchestrator.
            self._send_message(TextMessage(content=request, source=sender), self._orchestrator)
            # Request a response.
            step_response = await self._send_message(
                RespondNow(response_format=ResponseFormat.json_object),
                self._orchestrator,
            )
            # TODO: use typed dictionary.
            try:
                result = json.loads(str(step_response.content))
            except json.JSONDecodeError as e:
                request = f"Invalid JSON: {str(e)}"
                continue
            if "is_request_satisfied" not in result:
                request = "Missing key: is_request_satisfied"
                continue
            elif (
                not isinstance(result["is_request_satisfied"], dict)
                or "answer" not in result["is_request_satisfied"]
                or "reason" not in result["is_request_satisfied"]
            ):
                request = "Invalid value for key: is_request_satisfied, expected 'answer' and 'reason'"
                continue
            if "is_progress_being_made" not in result:
                request = "Missing key: is_progress_being_made"
                continue
            elif (
                not isinstance(result["is_progress_being_made"], dict)
                or "answer" not in result["is_progress_being_made"]
                or "reason" not in result["is_progress_being_made"]
            ):
                request = "Invalid value for key: is_progress_being_made, expected 'answer' and 'reason'"
                continue
            if "next_speaker" not in result:
                request = "Missing key: next_speaker"
                continue
            elif (
                not isinstance(result["next_speaker"], dict)
                or "answer" not in result["next_speaker"]
                or "reason" not in result["next_speaker"]
            ):
                request = "Invalid value for key: next_speaker, expected 'answer' and 'reason'"
                continue
            elif result["next_speaker"]["answer"] not in names:
                request = f"Invalid value for key: next_speaker, expected 'answer' in {names}"
                continue
            if "instruction_or_question" not in result:
                request = "Missing key: instruction_or_question"
                continue
            elif (
                not isinstance(result["instruction_or_question"], dict)
                or "answer" not in result["instruction_or_question"]
                or "reason" not in result["instruction_or_question"]
            ):
                request = "Invalid value for key: instruction_or_question, expected 'answer' and 'reason'"
                continue
            return result

    async def _rewrite_facts(self, facts: str, sender: str) -> str:
        new_facts_prompt = f"""It's clear we aren't making as much progress as we would like, but we may have learned something new. Please rewrite the following fact sheet, updating it to include anything new we have learned. This is also a good time to update educated guesses (please add or update at least one educated guess or hunch, and explain your reasoning).
@@ -293,15 +345,35 @@ Please output an answer in pure JSON format according to the following schema. T
}}
}}
""".strip()
        # Send a message to the orchestrator.
        self._send_message(TextMessage(content=educated_guess_promt, source=sender), self._orchestrator)
        # Request a response.
        educated_guess_response = await self._send_message(
            RespondNow(response_format=ResponseFormat.json_object), self._orchestrator
        )
        # TODO: handle invalid JSON.
        # TODO: use typed dictionary.
        return json.loads(str(educated_guess_response.content))
        request = educated_guess_promt
        while True:
            # Send a message to the orchestrator.
            self._send_message(
                TextMessage(content=request, source=sender),
                self._orchestrator,
            )
            # Request a response.
            response = await self._send_message(
                RespondNow(response_format=ResponseFormat.json_object),
                self._orchestrator,
            )
            try:
                result = json.loads(str(response.content))
            except json.JSONDecodeError as e:
                request = f"Invalid JSON: {str(e)}"
                continue
            # TODO: use typed dictionary.
            if "has_educated_guesses" not in result:
                request = "Missing key: has_educated_guesses"
                continue
            if (
                not isinstance(result["has_educated_guesses"], dict)
                or "answer" not in result["has_educated_guesses"]
                or "reason" not in result["has_educated_guesses"]
            ):
                request = "Invalid value for key: has_educated_guesses, expected 'answer' and 'reason'"
                continue
            return result

    async def _rewrite_plan(self, team: str, sender: str) -> str:
        new_plan_prompt = f"""Please come up with a new plan expressed in bullet points. Keep in mind the following team composition, and do not involve any other outside people in the plan -- we cannot contact anyone else.
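Both orchestrator prompts previously carried a "TODO: handle invalid JSON"; the replacement is the same validate-and-re-prompt idea in both places: parse the reply, check the expected keys, and on any failure send the error text back to the orchestrator as the next request. Distilled into a hypothetical helper (no such method exists in this diff):

# Sketch of the retry pattern used above.
async def _ask_until_valid(self, prompt: str, sender: str, required_keys: List[str]) -> Any:
    request = prompt
    while True:
        self._send_message(TextMessage(content=request, source=sender), self._orchestrator)
        response = await self._send_message(
            RespondNow(response_format=ResponseFormat.json_object), self._orchestrator
        )
        try:
            result = json.loads(str(response.content))
        except json.JSONDecodeError as e:
            request = f"Invalid JSON: {str(e)}"
            continue
        missing = [key for key in required_keys if key not in result]
        if missing:
            request = f"Missing key: {missing[0]}"
            continue
        return result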
@@ -2,8 +2,24 @@ from typing import List, Optional, Union

from typing_extensions import Literal

from agnext.agent_components.types import AssistantMessage, LLMMessage, UserMessage
from agnext.chat.types import FunctionCallMessage, Message, MultiModalMessage, TextMessage
from agnext.agent_components.types import (
    AssistantMessage,
    LLMMessage,
    UserMessage,
)
from agnext.agent_components.types import (
    FunctionExecutionResult as FunctionExecutionResultType,
)
from agnext.agent_components.types import (
    FunctionExecutionResultMessage as FunctionExecutionResultMessageType,
)
from agnext.chat.types import (
    FunctionCallMessage,
    FunctionExecutionResultMessage,
    Message,
    MultiModalMessage,
    TextMessage,
)


def convert_content_message_to_assistant_message(
@@ -20,7 +36,8 @@ def convert_content_message_to_assistant_message(
        return None
    elif handle_unrepresentable == "try_slice":
        return AssistantMessage(
            content="".join([x for x in message.content if isinstance(x, str)]), source=message.source
            content="".join([x for x in message.content if isinstance(x, str)]),
            source=message.source,
        )


@@ -41,8 +58,21 @@ def convert_content_message_to_user_message(
        raise NotImplementedError("Sliced function calls not yet implemented")


def convert_tool_call_response_message(
    message: FunctionExecutionResultMessage,
    handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> Optional[FunctionExecutionResultMessageType]:
    match message:
        case FunctionExecutionResultMessage():
            return FunctionExecutionResultMessageType(
                content=[FunctionExecutionResultType(content=x.content, call_id=x.call_id) for x in message.content]
            )


def convert_messages_to_llm_messages(
    messages: List[Message], self_name: str, handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error"
    messages: List[Message],
    self_name: str,
    handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> List[LLMMessage]:
    result: List[LLMMessage] = []
    for message in messages:
@@ -63,6 +93,10 @@ def convert_messages_to_llm_messages(
                converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
                if converted_message_2 is not None:
                    result.append(converted_message_2)
            case FunctionExecutionResultMessage(_, source=source) if source == self_name:
                converted_message_3 = convert_tool_call_response_message(message, handle_unrepresentable)
                if converted_message_3 is not None:
                    result.append(converted_message_3)
            case _:
                raise AssertionError("unreachable")
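With the new match case in place, a transcript containing tool results can be flattened for the model instead of hitting the unreachable branch. A short sketch reusing names from this file; the message values are illustrative:

# Sketch: converting a history that includes tool-call results.
history: List[Message] = [
    TextMessage(content="Search for agnext.", source="User"),
    FunctionExecutionResultMessage(
        content=[FunctionExecutionResult(content="Top hit: ...", call_id="call_1")],
        source="Assistant",
    ),
]
llm_messages = convert_messages_to_llm_messages(history, self_name="Assistant")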
@@ -3,12 +3,17 @@

import tempfile

import pytest

from agnext.agent_components.code_executor import LocalCommandLineCodeExecutor, CodeBlock
from agnext.agent_components.func_with_reqs import FunctionWithRequirements, with_requirements

import polars
import pytest
from agnext.agent_components.code_executor import (
    CodeBlock,
    LocalCommandLineCodeExecutor,
)
from agnext.agent_components.func_with_reqs import (
    FunctionWithRequirements,
    with_requirements,
)


def add_two_numbers(a: int, b: int) -> int:
    """Add two numbers together."""
@@ -46,7 +51,9 @@ def function_missing_reqs() -> "polars.DataFrame":

def test_can_load_function_with_reqs() -> None:
    with tempfile.TemporaryDirectory() as temp_dir:
        executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[load_data])
        executor = LocalCommandLineCodeExecutor(
            work_dir=temp_dir, functions=[load_data]
        )
        code = f"""from {executor.functions_module} import load_data
import polars

@@ -65,7 +72,9 @@ print(data['name'][0])"""

def test_can_load_function() -> None:
    with tempfile.TemporaryDirectory() as temp_dir:
        executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[add_two_numbers])
        executor = LocalCommandLineCodeExecutor(
            work_dir=temp_dir, functions=[add_two_numbers]
        )
        code = f"""from {executor.functions_module} import add_two_numbers
print(add_two_numbers(1, 2))"""

@@ -80,7 +89,9 @@ print(add_two_numbers(1, 2))"""

def test_fails_for_function_incorrect_import() -> None:
    with tempfile.TemporaryDirectory() as temp_dir:
        executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[function_incorrect_import])
        executor = LocalCommandLineCodeExecutor(
            work_dir=temp_dir, functions=[function_incorrect_import]
        )
        code = f"""from {executor.functions_module} import function_incorrect_import
function_incorrect_import()"""

@@ -94,7 +105,9 @@ function_incorrect_import()"""

def test_fails_for_function_incorrect_dep() -> None:
    with tempfile.TemporaryDirectory() as temp_dir:
        executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[function_incorrect_dep])
        executor = LocalCommandLineCodeExecutor(
            work_dir=temp_dir, functions=[function_incorrect_dep]
        )
        code = f"""from {executor.functions_module} import function_incorrect_dep
function_incorrect_dep()"""

@@ -106,10 +119,11 @@ function_incorrect_dep()"""
    )



def test_formatted_prompt() -> None:
    with tempfile.TemporaryDirectory() as temp_dir:
        executor = LocalCommandLineCodeExecutor(work_dir=temp_dir, functions=[add_two_numbers])
        executor = LocalCommandLineCodeExecutor(
            work_dir=temp_dir, functions=[add_two_numbers]
        )

        result = executor.format_functions_for_prompt()
        assert (
@@ -140,7 +154,6 @@ def add_two_numbers(a: int, b: int) -> int:
    )



def test_can_load_str_function_with_reqs() -> None:
    with tempfile.TemporaryDirectory() as temp_dir:
        func = FunctionWithRequirements.from_str(