mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-25 05:59:19 +00:00
Introducing IOStream protocol and adding support for websockets (#1551)
* Introducing IOStream * bug fixing * polishing * refactoring * refactoring * refactoring * wip: async tests * websockets added * wip * merge with main * notebook added * FastAPI example added * wip * merge * getter/setter to iostream added * website/blog/2024-03-03-AutoGen-Update/img/dalle_gpt4v.png: convert to Git LFS * website/blog/2024-03-03-AutoGen-Update/img/gaia.png: convert to Git LFS * website/blog/2024-03-03-AutoGen-Update/img/teach.png: convert to Git LFS * add SSL support * wip * wip * exception handling added to on_connect() * refactoring: default iostream is being set in a context manager * test fix * polishing * polishing * polishing * fixed bug with new thread * polishing * a bit of refactoring and docs added * notebook added to docs * type checking added to CI * CI fix * CI fix * CI fix * polishing * obsolete todo comment removed * fixed precommit error --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
72994ea127
commit
78aa0eb220
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@ -65,7 +65,7 @@ jobs:
|
||||
- name: Coverage
|
||||
if: matrix.python-version == '3.10'
|
||||
run: |
|
||||
pip install -e .[test,redis]
|
||||
pip install -e .[test,redis,websockets]
|
||||
coverage run -a -m pytest test --ignore=test/agentchat/contrib --skip-openai --durations=10 --durations-min=1.0
|
||||
coverage xml
|
||||
- name: Upload coverage to Codecov
|
||||
|
||||
@ -7,6 +7,7 @@ from dataclasses import dataclass
|
||||
from .utils import consolidate_chat_info
|
||||
import datetime
|
||||
import warnings
|
||||
from ..io.base import IOStream
|
||||
from ..formatting_utils import colored
|
||||
|
||||
|
||||
@ -103,6 +104,8 @@ def __find_async_chat_order(chat_ids: Set[int], prerequisites: List[Prerequisite
|
||||
|
||||
|
||||
def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
|
||||
iostream = IOStream.get_default()
|
||||
|
||||
if "message" not in chat_info:
|
||||
warnings.warn(
|
||||
"message is not provided in a chat_queue entry. input() will be called to get the initial message.",
|
||||
@ -122,8 +125,8 @@ def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
|
||||
print_message = "Dict: " + str(message)
|
||||
elif message is None:
|
||||
print_message = "None"
|
||||
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
|
||||
print(
|
||||
iostream.print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
|
||||
iostream.print(
|
||||
colored(
|
||||
"Starting a new chat....",
|
||||
"blue",
|
||||
@ -131,9 +134,9 @@ def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
|
||||
flush=True,
|
||||
)
|
||||
if chat_info.get("verbose", False):
|
||||
print(colored("Message:\n" + print_message, "blue"), flush=True)
|
||||
print(colored("Carryover:\n" + print_carryover, "blue"), flush=True)
|
||||
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
|
||||
iostream.print(colored("Message:\n" + print_message, "blue"), flush=True)
|
||||
iostream.print(colored("Carryover:\n" + print_carryover, "blue"), flush=True)
|
||||
iostream.print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
|
||||
|
||||
|
||||
def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
|
||||
|
||||
@ -32,6 +32,7 @@ from ..function_utils import get_function_schema, load_basemodels_if_needed, ser
|
||||
from ..oai.client import ModelClient, OpenAIWrapper
|
||||
from ..runtime_logging import log_new_agent, logging_enabled
|
||||
from .agent import Agent, LLMAgent
|
||||
from ..io.base import IOStream
|
||||
from .chat import ChatResult, a_initiate_chats, initiate_chats
|
||||
from .utils import consolidate_chat_info, gather_usage_summary
|
||||
|
||||
@ -681,8 +682,9 @@ class ConversableAgent(LLMAgent):
|
||||
)
|
||||
|
||||
def _print_received_message(self, message: Union[Dict, str], sender: Agent):
|
||||
iostream = IOStream.get_default()
|
||||
# print the message received
|
||||
print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True)
|
||||
iostream.print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True)
|
||||
message = self._message_to_dict(message)
|
||||
|
||||
if message.get("tool_responses"): # Handle tool multi-call responses
|
||||
@ -698,9 +700,9 @@ class ConversableAgent(LLMAgent):
|
||||
id_key = "tool_call_id"
|
||||
id = message.get(id_key, "No id found")
|
||||
func_print = f"***** Response from calling {message['role']} ({id}) *****"
|
||||
print(colored(func_print, "green"), flush=True)
|
||||
print(message["content"], flush=True)
|
||||
print(colored("*" * len(func_print), "green"), flush=True)
|
||||
iostream.print(colored(func_print, "green"), flush=True)
|
||||
iostream.print(message["content"], flush=True)
|
||||
iostream.print(colored("*" * len(func_print), "green"), flush=True)
|
||||
else:
|
||||
content = message.get("content")
|
||||
if content is not None:
|
||||
@ -710,35 +712,35 @@ class ConversableAgent(LLMAgent):
|
||||
message["context"],
|
||||
self.llm_config and self.llm_config.get("allow_format_str_template", False),
|
||||
)
|
||||
print(content_str(content), flush=True)
|
||||
iostream.print(content_str(content), flush=True)
|
||||
if "function_call" in message and message["function_call"]:
|
||||
function_call = dict(message["function_call"])
|
||||
func_print = (
|
||||
f"***** Suggested function call: {function_call.get('name', '(No function name found)')} *****"
|
||||
)
|
||||
print(colored(func_print, "green"), flush=True)
|
||||
print(
|
||||
iostream.print(colored(func_print, "green"), flush=True)
|
||||
iostream.print(
|
||||
"Arguments: \n",
|
||||
function_call.get("arguments", "(No arguments found)"),
|
||||
flush=True,
|
||||
sep="",
|
||||
)
|
||||
print(colored("*" * len(func_print), "green"), flush=True)
|
||||
iostream.print(colored("*" * len(func_print), "green"), flush=True)
|
||||
if "tool_calls" in message and message["tool_calls"]:
|
||||
for tool_call in message["tool_calls"]:
|
||||
id = tool_call.get("id", "No tool call id found")
|
||||
function_call = dict(tool_call.get("function", {}))
|
||||
func_print = f"***** Suggested tool call ({id}): {function_call.get('name', '(No function name found)')} *****"
|
||||
print(colored(func_print, "green"), flush=True)
|
||||
print(
|
||||
iostream.print(colored(func_print, "green"), flush=True)
|
||||
iostream.print(
|
||||
"Arguments: \n",
|
||||
function_call.get("arguments", "(No arguments found)"),
|
||||
flush=True,
|
||||
sep="",
|
||||
)
|
||||
print(colored("*" * len(func_print), "green"), flush=True)
|
||||
iostream.print(colored("*" * len(func_print), "green"), flush=True)
|
||||
|
||||
print("\n", "-" * 80, flush=True, sep="")
|
||||
iostream.print("\n", "-" * 80, flush=True, sep="")
|
||||
|
||||
def _process_received_message(self, message: Union[Dict, str], sender: Agent, silent: bool):
|
||||
# When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.)
|
||||
@ -1229,6 +1231,7 @@ class ConversableAgent(LLMAgent):
|
||||
recipient: the agent with whom the chat history to clear. If None, clear the chat history with all agents.
|
||||
nr_messages_to_preserve: the number of newest messages to preserve in the chat history.
|
||||
"""
|
||||
iostream = IOStream.get_default()
|
||||
if recipient is None:
|
||||
if nr_messages_to_preserve:
|
||||
for key in self._oai_messages:
|
||||
@ -1238,7 +1241,7 @@ class ConversableAgent(LLMAgent):
|
||||
first_msg_to_save = self._oai_messages[key][-nr_messages_to_preserve_internal]
|
||||
if "tool_responses" in first_msg_to_save:
|
||||
nr_messages_to_preserve_internal += 1
|
||||
print(
|
||||
iostream.print(
|
||||
f"Preserving one more message for {self.name} to not divide history between tool call and "
|
||||
f"tool response."
|
||||
)
|
||||
@ -1249,7 +1252,7 @@ class ConversableAgent(LLMAgent):
|
||||
else:
|
||||
self._oai_messages[recipient].clear()
|
||||
if nr_messages_to_preserve:
|
||||
print(
|
||||
iostream.print(
|
||||
colored(
|
||||
"WARNING: `nr_preserved_messages` is ignored when clearing chat history with a specific agent.",
|
||||
"yellow",
|
||||
@ -1323,8 +1326,19 @@ class ConversableAgent(LLMAgent):
|
||||
config: Optional[Any] = None,
|
||||
) -> Tuple[bool, Union[str, Dict, None]]:
|
||||
"""Generate a reply using autogen.oai asynchronously."""
|
||||
iostream = IOStream.get_default()
|
||||
|
||||
def _generate_oai_reply(
|
||||
self, iostream: IOStream, *args: Any, **kwargs: Any
|
||||
) -> Tuple[bool, Union[str, Dict, None]]:
|
||||
with IOStream.set_default(iostream):
|
||||
return self.generate_oai_reply(*args, **kwargs)
|
||||
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None, functools.partial(self.generate_oai_reply, messages=messages, sender=sender, config=config)
|
||||
None,
|
||||
functools.partial(
|
||||
_generate_oai_reply, self=self, iostream=iostream, messages=messages, sender=sender, config=config
|
||||
),
|
||||
)
|
||||
|
||||
def _generate_code_execution_reply_using_executor(
|
||||
@ -1334,6 +1348,8 @@ class ConversableAgent(LLMAgent):
|
||||
config: Optional[Union[Dict, Literal[False]]] = None,
|
||||
):
|
||||
"""Generate a reply using code executor."""
|
||||
iostream = IOStream.get_default()
|
||||
|
||||
if config is not None:
|
||||
raise ValueError("config is not supported for _generate_code_execution_reply_using_executor.")
|
||||
if self._code_execution_config is False:
|
||||
@ -1371,7 +1387,7 @@ class ConversableAgent(LLMAgent):
|
||||
|
||||
num_code_blocks = len(code_blocks)
|
||||
if num_code_blocks == 1:
|
||||
print(
|
||||
iostream.print(
|
||||
colored(
|
||||
f"\n>>>>>>>> EXECUTING CODE BLOCK (inferred language is {code_blocks[0].language})...",
|
||||
"red",
|
||||
@ -1379,7 +1395,7 @@ class ConversableAgent(LLMAgent):
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
iostream.print(
|
||||
colored(
|
||||
f"\n>>>>>>>> EXECUTING {num_code_blocks} CODE BLOCKS (inferred languages are [{', '.join([x.language for x in code_blocks])}])...",
|
||||
"red",
|
||||
@ -1631,6 +1647,8 @@ class ConversableAgent(LLMAgent):
|
||||
- Tuple[bool, Union[str, Dict, None]]: A tuple containing a boolean indicating if the conversation
|
||||
should be terminated, and a human reply which can be a string, a dictionary, or None.
|
||||
"""
|
||||
iostream = IOStream.get_default()
|
||||
|
||||
if config is None:
|
||||
config = self
|
||||
if messages is None:
|
||||
@ -1675,7 +1693,7 @@ class ConversableAgent(LLMAgent):
|
||||
|
||||
# print the no_human_input_msg
|
||||
if no_human_input_msg:
|
||||
print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True)
|
||||
iostream.print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True)
|
||||
|
||||
# stop the conversation
|
||||
if reply == "exit":
|
||||
@ -1715,7 +1733,7 @@ class ConversableAgent(LLMAgent):
|
||||
# increment the consecutive_auto_reply_counter
|
||||
self._consecutive_auto_reply_counter[sender] += 1
|
||||
if self.human_input_mode != "NEVER":
|
||||
print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True)
|
||||
iostream.print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True)
|
||||
|
||||
return False, None
|
||||
|
||||
@ -1742,6 +1760,8 @@ class ConversableAgent(LLMAgent):
|
||||
- Tuple[bool, Union[str, Dict, None]]: A tuple containing a boolean indicating if the conversation
|
||||
should be terminated, and a human reply which can be a string, a dictionary, or None.
|
||||
"""
|
||||
iostream = IOStream.get_default()
|
||||
|
||||
if config is None:
|
||||
config = self
|
||||
if messages is None:
|
||||
@ -1786,7 +1806,7 @@ class ConversableAgent(LLMAgent):
|
||||
|
||||
# print the no_human_input_msg
|
||||
if no_human_input_msg:
|
||||
print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True)
|
||||
iostream.print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True)
|
||||
|
||||
# stop the conversation
|
||||
if reply == "exit":
|
||||
@ -1826,7 +1846,7 @@ class ConversableAgent(LLMAgent):
|
||||
# increment the consecutive_auto_reply_counter
|
||||
self._consecutive_auto_reply_counter[sender] += 1
|
||||
if self.human_input_mode != "NEVER":
|
||||
print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True)
|
||||
iostream.print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True)
|
||||
|
||||
return False, None
|
||||
|
||||
@ -2001,7 +2021,9 @@ class ConversableAgent(LLMAgent):
|
||||
Returns:
|
||||
str: human input.
|
||||
"""
|
||||
reply = input(prompt)
|
||||
iostream = IOStream.get_default()
|
||||
|
||||
reply = iostream.input(prompt)
|
||||
self._human_input.append(reply)
|
||||
return reply
|
||||
|
||||
@ -2016,8 +2038,8 @@ class ConversableAgent(LLMAgent):
|
||||
Returns:
|
||||
str: human input.
|
||||
"""
|
||||
reply = input(prompt)
|
||||
self._human_input.append(reply)
|
||||
loop = asyncio.get_running_loop()
|
||||
reply = await loop.run_in_executor(None, functools.partial(self.get_human_input, prompt))
|
||||
return reply
|
||||
|
||||
def run_code(self, code, **kwargs):
|
||||
@ -2038,12 +2060,14 @@ class ConversableAgent(LLMAgent):
|
||||
|
||||
def execute_code_blocks(self, code_blocks):
|
||||
"""Execute the code blocks and return the result."""
|
||||
iostream = IOStream.get_default()
|
||||
|
||||
logs_all = ""
|
||||
for i, code_block in enumerate(code_blocks):
|
||||
lang, code = code_block
|
||||
if not lang:
|
||||
lang = infer_lang(code)
|
||||
print(
|
||||
iostream.print(
|
||||
colored(
|
||||
f"\n>>>>>>>> EXECUTING CODE BLOCK {i} (inferred language is {lang})...",
|
||||
"red",
|
||||
@ -2124,6 +2148,8 @@ class ConversableAgent(LLMAgent):
|
||||
"function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0)
|
||||
See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call
|
||||
"""
|
||||
iostream = IOStream.get_default()
|
||||
|
||||
func_name = func_call.get("name", "")
|
||||
func = self._function_map.get(func_name, None)
|
||||
|
||||
@ -2139,7 +2165,7 @@ class ConversableAgent(LLMAgent):
|
||||
|
||||
# Try to execute the function
|
||||
if arguments is not None:
|
||||
print(
|
||||
iostream.print(
|
||||
colored(f"\n>>>>>>>> EXECUTING FUNCTION {func_name}...", "magenta"),
|
||||
flush=True,
|
||||
)
|
||||
@ -2152,7 +2178,7 @@ class ConversableAgent(LLMAgent):
|
||||
content = f"Error: Function {func_name} not found."
|
||||
|
||||
if verbose:
|
||||
print(
|
||||
iostream.print(
|
||||
colored(f"\nInput arguments: {arguments}\nOutput:\n{content}", "magenta"),
|
||||
flush=True,
|
||||
)
|
||||
@ -2179,6 +2205,8 @@ class ConversableAgent(LLMAgent):
|
||||
"function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0)
|
||||
See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call
|
||||
"""
|
||||
iostream = IOStream.get_default()
|
||||
|
||||
func_name = func_call.get("name", "")
|
||||
func = self._function_map.get(func_name, None)
|
||||
|
||||
@ -2194,7 +2222,7 @@ class ConversableAgent(LLMAgent):
|
||||
|
||||
# Try to execute the function
|
||||
if arguments is not None:
|
||||
print(
|
||||
iostream.print(
|
||||
colored(f"\n>>>>>>>> EXECUTING ASYNC FUNCTION {func_name}...", "magenta"),
|
||||
flush=True,
|
||||
)
|
||||
@ -2639,10 +2667,12 @@ class ConversableAgent(LLMAgent):
|
||||
|
||||
def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
|
||||
"""Print the usage summary."""
|
||||
iostream = IOStream.get_default()
|
||||
|
||||
if self.client is None:
|
||||
print(f"No cost incurred from agent '{self.name}'.")
|
||||
iostream.print(f"No cost incurred from agent '{self.name}'.")
|
||||
else:
|
||||
print(f"Agent '{self.name}':")
|
||||
iostream.print(f"Agent '{self.name}':")
|
||||
self.client.print_usage_summary(mode)
|
||||
|
||||
def get_actual_usage(self) -> Union[None, Dict[str, int]]:
|
||||
|
||||
@ -5,9 +5,9 @@ import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from autogen.agentchat.agent import Agent
|
||||
from autogen.agentchat.conversable_agent import ConversableAgent
|
||||
|
||||
from .agent import Agent
|
||||
from .conversable_agent import ConversableAgent
|
||||
from ..io.base import IOStream
|
||||
from ..code_utils import content_str
|
||||
from ..exception_utils import AgentNameConflict, NoEligibleSpeaker, UndefinedNextAgent
|
||||
from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed
|
||||
@ -257,22 +257,26 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
|
||||
|
||||
def manual_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[Agent, None]:
|
||||
"""Manually select the next speaker."""
|
||||
iostream = IOStream.get_default()
|
||||
|
||||
if agents is None:
|
||||
agents = self.agents
|
||||
|
||||
print("Please select the next speaker from the following list:")
|
||||
iostream.print("Please select the next speaker from the following list:")
|
||||
_n_agents = len(agents)
|
||||
for i in range(_n_agents):
|
||||
print(f"{i+1}: {agents[i].name}")
|
||||
iostream.print(f"{i+1}: {agents[i].name}")
|
||||
try_count = 0
|
||||
# Assume the user will enter a valid number within 3 tries, otherwise use auto selection to avoid blocking.
|
||||
while try_count <= 3:
|
||||
try_count += 1
|
||||
if try_count >= 3:
|
||||
print(f"You have tried {try_count} times. The next speaker will be selected automatically.")
|
||||
iostream.print(f"You have tried {try_count} times. The next speaker will be selected automatically.")
|
||||
break
|
||||
try:
|
||||
i = input("Enter the number of the next speaker (enter nothing or `q` to use auto selection): ")
|
||||
i = iostream.input(
|
||||
"Enter the number of the next speaker (enter nothing or `q` to use auto selection): "
|
||||
)
|
||||
if i == "" or i == "q":
|
||||
break
|
||||
i = int(i)
|
||||
@ -281,7 +285,7 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
|
||||
else:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
print(f"Invalid input. Please enter a number between 1 and {_n_agents}.")
|
||||
iostream.print(f"Invalid input. Please enter a number between 1 and {_n_agents}.")
|
||||
return None
|
||||
|
||||
def random_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[Agent, None]:
|
||||
@ -740,6 +744,8 @@ class GroupChatManager(ConversableAgent):
|
||||
reply (dict): reply message dict to analyze.
|
||||
groupchat (GroupChat): GroupChat object.
|
||||
"""
|
||||
iostream = IOStream.get_default()
|
||||
|
||||
reply_content = reply["content"]
|
||||
# Split the reply into words
|
||||
words = reply_content.split()
|
||||
@ -775,21 +781,21 @@ class GroupChatManager(ConversableAgent):
|
||||
# clear history
|
||||
if agent_to_memory_clear:
|
||||
if nr_messages_to_preserve:
|
||||
print(
|
||||
iostream.print(
|
||||
f"Clearing history for {agent_to_memory_clear.name} except last {nr_messages_to_preserve} messages."
|
||||
)
|
||||
else:
|
||||
print(f"Clearing history for {agent_to_memory_clear.name}.")
|
||||
iostream.print(f"Clearing history for {agent_to_memory_clear.name}.")
|
||||
agent_to_memory_clear.clear_history(nr_messages_to_preserve=nr_messages_to_preserve)
|
||||
else:
|
||||
if nr_messages_to_preserve:
|
||||
print(f"Clearing history for all agents except last {nr_messages_to_preserve} messages.")
|
||||
iostream.print(f"Clearing history for all agents except last {nr_messages_to_preserve} messages.")
|
||||
# clearing history for groupchat here
|
||||
temp = groupchat.messages[-nr_messages_to_preserve:]
|
||||
groupchat.messages.clear()
|
||||
groupchat.messages.extend(temp)
|
||||
else:
|
||||
print("Clearing history for all agents.")
|
||||
iostream.print("Clearing history for all agents.")
|
||||
# clearing history for groupchat here
|
||||
groupchat.messages.clear()
|
||||
# clearing history for agents
|
||||
|
||||
6
autogen/cache/cache.py
vendored
6
autogen/cache/cache.py
vendored
@ -30,7 +30,7 @@ class Cache:
|
||||
ALLOWED_CONFIG_KEYS = ["cache_seed", "redis_url", "cache_path_root"]
|
||||
|
||||
@staticmethod
|
||||
def redis(cache_seed: Union[str, int] = 42, redis_url: str = "redis://localhost:6379/0") -> Cache:
|
||||
def redis(cache_seed: Union[str, int] = 42, redis_url: str = "redis://localhost:6379/0") -> "Cache":
|
||||
"""
|
||||
Create a Redis cache instance.
|
||||
|
||||
@ -44,7 +44,7 @@ class Cache:
|
||||
return Cache({"cache_seed": cache_seed, "redis_url": redis_url})
|
||||
|
||||
@staticmethod
|
||||
def disk(cache_seed: Union[str, int] = 42, cache_path_root: str = ".cache") -> Cache:
|
||||
def disk(cache_seed: Union[str, int] = 42, cache_path_root: str = ".cache") -> "Cache":
|
||||
"""
|
||||
Create a Disk cache instance.
|
||||
|
||||
@ -81,7 +81,7 @@ class Cache:
|
||||
self.config.get("cache_path_root", None),
|
||||
)
|
||||
|
||||
def __enter__(self) -> AbstractCache:
|
||||
def __enter__(self) -> "Cache":
|
||||
"""
|
||||
Enter the runtime context related to the cache object.
|
||||
|
||||
|
||||
8
autogen/io/__init__.py
Normal file
8
autogen/io/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
from .base import InputStream, IOStream, OutputStream
|
||||
from .console import IOConsole
|
||||
from .websockets import IOWebsockets
|
||||
|
||||
# Set the default input/output stream to the console
|
||||
IOStream._default_io_stream.set(IOConsole())
|
||||
|
||||
__all__ = ("IOConsole", "IOStream", "InputStream", "OutputStream", "IOWebsockets")
|
||||
73
autogen/io/base.py
Normal file
73
autogen/io/base.py
Normal file
@ -0,0 +1,73 @@
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Iterator, Optional, Protocol, runtime_checkable
|
||||
|
||||
__all__ = ("OutputStream", "InputStream", "IOStream")
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class OutputStream(Protocol):
|
||||
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
|
||||
"""Print data to the output stream.
|
||||
|
||||
Args:
|
||||
objects (any): The data to print.
|
||||
sep (str, optional): The separator between objects. Defaults to " ".
|
||||
end (str, optional): The end of the output. Defaults to "\n".
|
||||
flush (bool, optional): Whether to flush the output. Defaults to False.
|
||||
"""
|
||||
... # pragma: no cover
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class InputStream(Protocol):
|
||||
def input(self, prompt: str = "", *, password: bool = False) -> str:
|
||||
"""Read a line from the input stream.
|
||||
|
||||
Args:
|
||||
prompt (str, optional): The prompt to display. Defaults to "".
|
||||
password (bool, optional): Whether to read a password. Defaults to False.
|
||||
|
||||
Returns:
|
||||
str: The line read from the input stream.
|
||||
|
||||
"""
|
||||
... # pragma: no cover
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class IOStream(InputStream, OutputStream, Protocol):
|
||||
"""A protocol for input/output streams."""
|
||||
|
||||
@staticmethod
|
||||
def get_default() -> "IOStream":
|
||||
"""Get the default input/output stream.
|
||||
|
||||
Returns:
|
||||
IOStream: The default input/output stream.
|
||||
"""
|
||||
iostream = IOStream._default_io_stream.get()
|
||||
if iostream is None:
|
||||
raise RuntimeError("No default IOStream has been set")
|
||||
return iostream
|
||||
|
||||
# ContextVar must be used in multithreaded or async environments
|
||||
_default_io_stream: ContextVar[Optional["IOStream"]] = ContextVar("default_iostream")
|
||||
_default_io_stream.set(None)
|
||||
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def set_default(stream: Optional["IOStream"]) -> Iterator[None]:
|
||||
"""Set the default input/output stream.
|
||||
|
||||
Args:
|
||||
stream (IOStream): The input/output stream to set as the default.
|
||||
"""
|
||||
global _default_io_stream
|
||||
try:
|
||||
token = IOStream._default_io_stream.set(stream)
|
||||
yield
|
||||
finally:
|
||||
IOStream._default_io_stream.reset(token)
|
||||
|
||||
return
|
||||
37
autogen/io/console.py
Normal file
37
autogen/io/console.py
Normal file
@ -0,0 +1,37 @@
|
||||
import getpass
|
||||
from typing import Any
|
||||
|
||||
from .base import IOStream
|
||||
|
||||
__all__ = ("IOConsole",)
|
||||
|
||||
|
||||
class IOConsole(IOStream):
|
||||
"""A console input/output stream."""
|
||||
|
||||
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
|
||||
"""Print data to the output stream.
|
||||
|
||||
Args:
|
||||
objects (any): The data to print.
|
||||
sep (str, optional): The separator between objects. Defaults to " ".
|
||||
end (str, optional): The end of the output. Defaults to "\n".
|
||||
flush (bool, optional): Whether to flush the output. Defaults to False.
|
||||
"""
|
||||
print(*objects, sep=sep, end=end, flush=flush)
|
||||
|
||||
def input(self, prompt: str = "", *, password: bool = False) -> str:
|
||||
"""Read a line from the input stream.
|
||||
|
||||
Args:
|
||||
prompt (str, optional): The prompt to display. Defaults to "".
|
||||
password (bool, optional): Whether to read a password. Defaults to False.
|
||||
|
||||
Returns:
|
||||
str: The line read from the input stream.
|
||||
|
||||
"""
|
||||
|
||||
if password:
|
||||
return getpass.getpass(prompt if prompt != "" else "Password: ")
|
||||
return input(prompt)
|
||||
207
autogen/io/websockets.py
Normal file
207
autogen/io/websockets.py
Normal file
@ -0,0 +1,207 @@
|
||||
import logging
|
||||
import ssl
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from time import sleep
|
||||
from typing import Any, Callable, Dict, Iterable, Iterator, Optional, TYPE_CHECKING, Protocol, Union
|
||||
|
||||
from .base import IOStream
|
||||
|
||||
# Check if the websockets module is available
|
||||
try:
|
||||
from websockets.sync.server import serve as ws_serve
|
||||
except ImportError as e:
|
||||
_import_error: Optional[ImportError] = e
|
||||
else:
|
||||
_import_error = None
|
||||
|
||||
|
||||
__all__ = ("IOWebsockets",)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# The following type and protocols are used to define the ServerConnection and WebSocketServer classes
|
||||
# if websockets is not installed, they would be untyped
|
||||
Data = Union[str, bytes]
|
||||
|
||||
|
||||
class ServerConnection(Protocol):
|
||||
def send(self, message: Union[Data, Iterable[Data]]) -> None:
|
||||
"""Send a message to the client.
|
||||
|
||||
Args:
|
||||
message (Union[Data, Iterable[Data]]): The message to send.
|
||||
|
||||
"""
|
||||
... # pragma: no cover
|
||||
|
||||
def recv(self, timeout: Optional[float] = None) -> Data:
|
||||
"""Receive a message from the client.
|
||||
|
||||
Args:
|
||||
timeout (Optional[float], optional): The timeout for the receive operation. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Data: The message received from the client.
|
||||
|
||||
"""
|
||||
... # pragma: no cover
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
...
|
||||
|
||||
|
||||
class WebSocketServer(Protocol):
|
||||
def serve_forever(self) -> None:
|
||||
"""Run the server forever."""
|
||||
... # pragma: no cover
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the server."""
|
||||
... # pragma: no cover
|
||||
|
||||
def __enter__(self) -> "WebSocketServer":
|
||||
"""Enter the server context."""
|
||||
... # pragma: no cover
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
"""Exit the server context."""
|
||||
... # pragma: no cover
|
||||
|
||||
|
||||
class IOWebsockets(IOStream):
|
||||
"""A websocket input/output stream."""
|
||||
|
||||
def __init__(self, websocket: ServerConnection) -> None:
|
||||
"""Initialize the websocket input/output stream.
|
||||
|
||||
Args:
|
||||
websocket (ServerConnection): The websocket server.
|
||||
|
||||
Raises:
|
||||
ImportError: If the websockets module is not available.
|
||||
"""
|
||||
if _import_error is not None:
|
||||
raise _import_error # pragma: no cover
|
||||
|
||||
self._websocket = websocket
|
||||
|
||||
@staticmethod
|
||||
def _handler(websocket: ServerConnection, on_connect: Callable[["IOWebsockets"], None]) -> None:
|
||||
"""The handler function for the websocket server."""
|
||||
logger.info(f" - IOWebsockets._handler(): Client connected on {websocket}")
|
||||
# create a new IOWebsockets instance using the websocket that is create when a client connects
|
||||
try:
|
||||
iowebsocket = IOWebsockets(websocket)
|
||||
with IOStream.set_default(iowebsocket):
|
||||
# call the on_connect function
|
||||
try:
|
||||
on_connect(iowebsocket)
|
||||
except Exception as e:
|
||||
logger.warning(f" - IOWebsockets._handler(): Error in on_connect: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f" - IOWebsockets._handler(): Unexpected error in IOWebsockets: {e}")
|
||||
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def run_server_in_thread(
|
||||
*,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 8765,
|
||||
on_connect: Callable[["IOWebsockets"], None],
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[str]:
|
||||
"""Factory function to create a websocket input/output stream.
|
||||
|
||||
Args:
|
||||
host (str, optional): The host to bind the server to. Defaults to "127.0.0.1".
|
||||
port (int, optional): The port to bind the server to. Defaults to 8765.
|
||||
on_connect (Callable[[IOWebsockets], None]): The function to be executed on client connection. Typically creates agents and initiate chat.
|
||||
ssl_context (Optional[ssl.SSLContext], optional): The SSL context to use for secure connections. Defaults to None.
|
||||
kwargs (Any): Additional keyword arguments to pass to the websocket server.
|
||||
|
||||
Yields:
|
||||
str: The URI of the websocket server.
|
||||
"""
|
||||
server_dict: Dict[str, WebSocketServer] = {}
|
||||
|
||||
def _run_server() -> None:
|
||||
if _import_error is not None:
|
||||
raise _import_error
|
||||
|
||||
# print(f" - _run_server(): starting server on ws://{host}:{port}", flush=True)
|
||||
with ws_serve(
|
||||
handler=partial(IOWebsockets._handler, on_connect=on_connect),
|
||||
host=host,
|
||||
port=port,
|
||||
ssl_context=ssl_context,
|
||||
**kwargs,
|
||||
) as server:
|
||||
# print(f" - _run_server(): server {server} started on ws://{host}:{port}", flush=True)
|
||||
|
||||
server_dict["server"] = server
|
||||
|
||||
# runs until the server is shutdown
|
||||
server.serve_forever()
|
||||
|
||||
return
|
||||
|
||||
# start server in a separate thread
|
||||
thread = threading.Thread(target=_run_server)
|
||||
thread.start()
|
||||
try:
|
||||
while "server" not in server_dict:
|
||||
sleep(0.1)
|
||||
|
||||
yield f"ws://{host}:{port}"
|
||||
|
||||
finally:
|
||||
# print(f" - run_server_in_thread(): shutting down server on ws://{host}:{port}", flush=True)
|
||||
# gracefully stop server
|
||||
if "server" in server_dict:
|
||||
# print(f" - run_server_in_thread(): shutting down server {server_dict['server']}", flush=True)
|
||||
server_dict["server"].shutdown()
|
||||
|
||||
# wait for the thread to stop
|
||||
if thread:
|
||||
thread.join()
|
||||
|
||||
@property
|
||||
def websocket(self) -> "ServerConnection":
|
||||
"""The URI of the websocket server."""
|
||||
return self._websocket
|
||||
|
||||
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
|
||||
"""Print data to the output stream.
|
||||
|
||||
Args:
|
||||
objects (any): The data to print.
|
||||
sep (str, optional): The separator between objects. Defaults to " ".
|
||||
end (str, optional): The end of the output. Defaults to "\n".
|
||||
flush (bool, optional): Whether to flush the output. Defaults to False.
|
||||
"""
|
||||
xs = sep.join(map(str, objects)) + end
|
||||
self._websocket.send(xs)
|
||||
|
||||
def input(self, prompt: str = "", *, password: bool = False) -> str:
|
||||
"""Read a line from the input stream.
|
||||
|
||||
Args:
|
||||
prompt (str, optional): The prompt to display. Defaults to "".
|
||||
password (bool, optional): Whether to read a password. Defaults to False.
|
||||
|
||||
Returns:
|
||||
str: The line read from the input stream.
|
||||
|
||||
"""
|
||||
if prompt != "":
|
||||
self._websocket.send(prompt)
|
||||
|
||||
msg = self._websocket.recv()
|
||||
|
||||
return msg.decode("utf-8") if isinstance(msg, bytes) else msg
|
||||
@ -11,6 +11,7 @@ from pydantic import BaseModel
|
||||
from typing import Protocol
|
||||
|
||||
from autogen.cache.cache import Cache
|
||||
from autogen.io.base import IOStream
|
||||
from autogen.oai.openai_utils import get_key, is_valid_api_key, OAI_PRICE1K
|
||||
from autogen.token_count_utils import count_token
|
||||
|
||||
@ -156,6 +157,8 @@ class OpenAIClient:
|
||||
Returns:
|
||||
The completion.
|
||||
"""
|
||||
iostream = IOStream.get_default()
|
||||
|
||||
completions: Completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions # type: ignore [attr-defined]
|
||||
# If streaming is enabled and has messages, then iterate over the chunks of the response.
|
||||
if params.get("stream", False) and "messages" in params:
|
||||
@ -164,7 +167,7 @@ class OpenAIClient:
|
||||
completion_tokens = 0
|
||||
|
||||
# Set the terminal text color to green
|
||||
print("\033[32m", end="")
|
||||
iostream.print("\033[32m", end="")
|
||||
|
||||
# Prepare for potential function call
|
||||
full_function_call: Optional[Dict[str, Any]] = None
|
||||
@ -216,15 +219,15 @@ class OpenAIClient:
|
||||
|
||||
# If content is present, print it to the terminal and update response variables
|
||||
if content is not None:
|
||||
print(content, end="", flush=True)
|
||||
iostream.print(content, end="", flush=True)
|
||||
response_contents[choice.index] += content
|
||||
completion_tokens += 1
|
||||
else:
|
||||
# print()
|
||||
# iostream.print()
|
||||
pass
|
||||
|
||||
# Reset the terminal text color
|
||||
print("\033[0m\n")
|
||||
iostream.print("\033[0m\n")
|
||||
|
||||
# Prepare the final ChatCompletion object based on the accumulated data
|
||||
model = chunk.model.replace("gpt-35", "gpt-3.5") # hack for Azure API
|
||||
@ -825,25 +828,26 @@ class OpenAIWrapper:
|
||||
|
||||
def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
|
||||
"""Print the usage summary."""
|
||||
iostream = IOStream.get_default()
|
||||
|
||||
def print_usage(usage_summary: Optional[Dict[str, Any]], usage_type: str = "total") -> None:
|
||||
word_from_type = "including" if usage_type == "total" else "excluding"
|
||||
if usage_summary is None:
|
||||
print("No actual cost incurred (all completions are using cache).", flush=True)
|
||||
iostream.print("No actual cost incurred (all completions are using cache).", flush=True)
|
||||
return
|
||||
|
||||
print(f"Usage summary {word_from_type} cached usage: ", flush=True)
|
||||
print(f"Total cost: {round(usage_summary['total_cost'], 5)}", flush=True)
|
||||
iostream.print(f"Usage summary {word_from_type} cached usage: ", flush=True)
|
||||
iostream.print(f"Total cost: {round(usage_summary['total_cost'], 5)}", flush=True)
|
||||
for model, counts in usage_summary.items():
|
||||
if model == "total_cost":
|
||||
continue #
|
||||
print(
|
||||
iostream.print(
|
||||
f"* Model '{model}': cost: {round(counts['cost'], 5)}, prompt_tokens: {counts['prompt_tokens']}, completion_tokens: {counts['completion_tokens']}, total_tokens: {counts['total_tokens']}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
if self.total_usage_summary is None:
|
||||
print('No usage summary. Please call "create" first.', flush=True)
|
||||
iostream.print('No usage summary. Please call "create" first.', flush=True)
|
||||
return
|
||||
|
||||
if isinstance(mode, list):
|
||||
@ -856,14 +860,14 @@ class OpenAIWrapper:
|
||||
elif "total" in mode:
|
||||
mode = "total"
|
||||
|
||||
print("-" * 100, flush=True)
|
||||
iostream.print("-" * 100, flush=True)
|
||||
if mode == "both":
|
||||
print_usage(self.actual_usage_summary, "actual")
|
||||
print()
|
||||
iostream.print()
|
||||
if self.total_usage_summary != self.actual_usage_summary:
|
||||
print_usage(self.total_usage_summary, "total")
|
||||
else:
|
||||
print(
|
||||
iostream.print(
|
||||
"All completions are non-cached: the total cost with cached completions is the same as actual cost.",
|
||||
flush=True,
|
||||
)
|
||||
@ -873,7 +877,7 @@ class OpenAIWrapper:
|
||||
print_usage(self.actual_usage_summary, "actual")
|
||||
else:
|
||||
raise ValueError(f'Invalid mode: {mode}, choose from "actual", "total", ["actual", "total"]')
|
||||
print("-" * 100, flush=True)
|
||||
iostream.print("-" * 100, flush=True)
|
||||
|
||||
def clear_usage_summary(self) -> None:
|
||||
"""Clear the usage summary."""
|
||||
|
||||
572
notebook/agentchat_websockets.ipynb
Normal file
572
notebook/agentchat_websockets.ipynb
Normal file
@ -0,0 +1,572 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "9a71fa36",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/microsoft/autogen/blob/main/notebook/agentchat_websockets.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
|
||||
"\n",
|
||||
"# Websockets: Streaming input and output using websockets\n",
|
||||
"\n",
|
||||
"This notebook demonstrates how to use the [`IOStream`](https://microsoft.github.io/autogen/docs/reference/io/base/IOStream) class to stream both input and output using websockets. The use of websockets allows you to build web clients that are more responsive than the one using web methods. The main difference is that the webosockets allows you to push data while you need to poll the server for new response using web mothods.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"In this guide, we explore the capabilities of the [`IOStream`](https://microsoft.github.io/autogen/docs/reference/io/base/IOStream) class. It is specifically designed to enhance the development of clients such as web clients which use websockets for streaming both input and output. The [`IOStream`](https://microsoft.github.io/autogen/docs/reference/io/base/IOStream) stands out by enabling a more dynamic and interactive user experience for web applications.\n",
|
||||
"\n",
|
||||
"Websockets technology is at the core of this functionality, offering a significant advancement over traditional web methods by allowing data to be \"pushed\" to the client in real-time. This is a departure from the conventional approach where clients must repeatedly \"poll\" the server to check for any new responses. By employing the underlining [websockets](https://websockets.readthedocs.io/) library, the IOStream class facilitates a continuous, two-way communication channel between the server and client. This ensures that updates are received instantly, without the need for constant polling, thereby making web clients more efficient and responsive.\n",
|
||||
"\n",
|
||||
"The real power of websockets, leveraged through the [`IOStream`](https://microsoft.github.io/autogen/docs/reference/io/base/IOStream) class, lies in its ability to create highly responsive web clients. This responsiveness is critical for applications requiring real-time data updates such as chat applications. By integrating the [`IOStream`](https://microsoft.github.io/autogen/docs/reference/io/base/IOStream) class into your web application, you not only enhance user experience through immediate data transmission but also reduce the load on your server by eliminating unnecessary polling.\n",
|
||||
"\n",
|
||||
"In essence, the transition to using websockets through the [`IOStream`](https://microsoft.github.io/autogen/docs/reference/io/base/IOStream) class marks a significant enhancement in web client development. This approach not only streamlines the data exchange process between clients and servers but also opens up new possibilities for creating more interactive and engaging web applications. By following this guide, developers can harness the full potential of websockets and the [`IOStream`](https://microsoft.github.io/autogen/docs/reference/io/base/IOStream) class to push the boundaries of what is possible with web client responsiveness and interactivity.\n",
|
||||
"\n",
|
||||
"## Requirements\n",
|
||||
"\n",
|
||||
"````{=mdx}\n",
|
||||
":::info Requirements\n",
|
||||
"Some extra dependencies are needed for this notebook, which can be installed via pip:\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"pip install pyautogen[websockets] fastapi uvicorn\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"For more information, please refer to the [installation guide](/docs/installation/).\n",
|
||||
":::\n",
|
||||
"````"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "5ebd2397",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Set your API Endpoint\n",
|
||||
"\n",
|
||||
"The [`config_list_from_json`](https://microsoft.github.io/autogen/docs/reference/oai/openai_utils#config_list_from_json) function loads a list of configurations from an environment variable or a json file."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "dca301a4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"gpt-4\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from tempfile import TemporaryDirectory\n",
|
||||
"\n",
|
||||
"from websockets.sync.client import connect as ws_connect\n",
|
||||
"\n",
|
||||
"import autogen\n",
|
||||
"from autogen.cache import Cache\n",
|
||||
"from autogen.io.websockets import IOStream, IOWebsockets\n",
|
||||
"\n",
|
||||
"config_list = autogen.config_list_from_json(\n",
|
||||
" \"OAI_CONFIG_LIST\",\n",
|
||||
" filter_dict={\n",
|
||||
" \"model\": [\"gpt-4\", \"gpt-3.5-turbo\", \"gpt-3.5-turbo-16k\"],\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(config_list[0][\"model\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "92fde41f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"````{=mdx}\n",
|
||||
":::tip\n",
|
||||
"Learn more about configuring LLMs for agents [here](/docs/topics/llm_configuration).\n",
|
||||
":::\n",
|
||||
"````"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "2b9526e7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Defining `on_connect` function\n",
|
||||
"\n",
|
||||
"An `on_connect` function is a crucial part of applications that utilize websockets, acting as an event handler that is called whenever a new client connection is established. This function is designed to initiate any necessary setup, communication protocols, or data exchange procedures specific to the newly connected client. Essentially, it lays the groundwork for the interactive session that follows, configuring how the server and the client will communicate and what initial actions are to be taken once a connection is made. Now, let's delve into the details of how to define this function, especially in the context of using the AutoGen framework with websockets.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Upon a client's connection to the websocket server, the server automatically initiates a new instance of the [`IOWebsockets`](https://microsoft.github.io/autogen/docs/reference/io/websockets/IOWebsockets) class. This instance is crucial for managing the data flow between the server and the client. The `on_connect` function leverages this instance to set up the communication protocol, define interaction rules, and initiate any preliminary data exchanges or configurations required for the client-server interaction to proceed smoothly.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "9fb85afb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def on_connect(iostream: IOWebsockets) -> None:\n",
|
||||
" print(f\" - on_connect(): Connected to client using IOWebsockets {iostream}\", flush=True)\n",
|
||||
"\n",
|
||||
" print(\" - on_connect(): Receiving message from client.\", flush=True)\n",
|
||||
"\n",
|
||||
" initial_msg = iostream.input()\n",
|
||||
"\n",
|
||||
" llm_config = {\n",
|
||||
" \"config_list\": config_list,\n",
|
||||
" \"stream\": True,\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" agent = autogen.ConversableAgent(\n",
|
||||
" name=\"chatbot\",\n",
|
||||
" system_message=\"Complete a task given to you and reply TERMINATE when the task is done. If asked about the weather, use tool weather_forecast(city) to get the weather forecast for a city.\",\n",
|
||||
" llm_config=llm_config,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # create a UserProxyAgent instance named \"user_proxy\"\n",
|
||||
" user_proxy = autogen.UserProxyAgent(\n",
|
||||
" name=\"user_proxy\",\n",
|
||||
" system_message=\"A proxy for the user.\",\n",
|
||||
" is_termination_msg=lambda x: x.get(\"content\", \"\") and x.get(\"content\", \"\").rstrip().endswith(\"TERMINATE\"),\n",
|
||||
" human_input_mode=\"NEVER\",\n",
|
||||
" max_consecutive_auto_reply=10,\n",
|
||||
" code_execution_config=False,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" @user_proxy.register_for_execution()\n",
|
||||
" @agent.register_for_llm(description=\"Weather forecats for a city\")\n",
|
||||
" def weather_forecast(city: str) -> str:\n",
|
||||
" return f\"The weather forecast for {city} is sunny.\"\n",
|
||||
"\n",
|
||||
" # we will use a temporary directory as the cache path root to ensure fresh completion each time\n",
|
||||
" with TemporaryDirectory() as cache_path_root:\n",
|
||||
" with Cache.disk(cache_path_root=cache_path_root) as cache:\n",
|
||||
" print(\n",
|
||||
" f\" - on_connect(): Initiating chat with agent {agent} using message '{initial_msg}'\",\n",
|
||||
" flush=True,\n",
|
||||
" )\n",
|
||||
" user_proxy.initiate_chat( # noqa: F704\n",
|
||||
" agent,\n",
|
||||
" message=initial_msg,\n",
|
||||
" cache=cache,\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a1124796",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Here's an explanation on how a typical `on_connect` function such as the one in the example above is defined:\n",
|
||||
"\n",
|
||||
"1. **Receiving Initial Message**: Immediately after establishing a connection, receive an initial message from the client. This step is crucial for understanding the client's request or initiating the conversation flow.\n",
|
||||
"\n",
|
||||
"2. **Receiving Initial Message**: Immediately after establishing a connection, receive an initial message from the client. This step is crucial for understanding the client's request or initiating the conversation flow.\n",
|
||||
"\n",
|
||||
"3. **Configure the LLM**: Define the configuration for your large language model (LLM), specifying the list of configurations and the streaming capability. This configuration will be used to tailor the behavior of your conversational agent.\n",
|
||||
"\n",
|
||||
"4. **Instantiate ConversableAgent and UserProxyAgent**: Create an instance of ConversableAgent with a specific system message and the LLM configuration. Similarly, create a UserProxyAgent instance, defining its termination condition, human input mode, and other relevant parameters.\n",
|
||||
"\n",
|
||||
"5. **Define Agent-specific Functions**: If your conversable agent requires executing specific tasks, such as fetching a weather forecast in the example below, define these functions within the on_connect scope. Decorate these functions accordingly to link them with your agents.\n",
|
||||
"\n",
|
||||
"5. **Initiate Conversation**: Finally, use the `initiate_chat` method of your `UserProxyAgent` to start the interaction with the conversable agent, passing the initial message and a cache mechanism for efficiency."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "62ef868a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Testing websockets server with Python client\n",
|
||||
"\n",
|
||||
"Testing an `on_connect` function with a Python client involves simulating a client-server interaction to ensure the setup, data exchange, and communication protocols function as intended. Here’s a brief explanation on how to conduct this test using a Python client:\n",
|
||||
"\n",
|
||||
"1. **Start the Websocket Server**: Use the `IOWebsockets.run_server_in_thread method` to start the server in a separate thread, specifying the on_connect function and the port. This method returns the URI of the running websocket server.\n",
|
||||
"\n",
|
||||
"2. **Connect to the Server**: Open a connection to the server using the returned URI. This simulates a client initiating a connection to your websocket server.\n",
|
||||
"\n",
|
||||
"3. **Send a Message to the Server**: Once connected, send a message from the client to the server. This tests the server's ability to receive messages through the established websocket connection.\n",
|
||||
"\n",
|
||||
"4. **Receive and Process Messages**: Implement a loop to continuously receive messages from the server. Decode the messages if necessary, and process them accordingly. This step verifies the server's ability to respond back to the client's request.\n",
|
||||
"\n",
|
||||
"This test scenario effectively evaluates the interaction between a client and a server using the `on_connect` function, by simulating a realistic message exchange. It ensures that the server can handle incoming connections, process messages, and communicate responses back to the client, all critical functionalities for a robust websocket-based application."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "4fbe004d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" - test_setup() with websocket server running on ws://127.0.0.1:8765.\n",
|
||||
" - on_connect(): Connected to client using IOWebsockets <autogen.io.websockets.IOWebsockets object at 0x75ad84aa0d60>\n",
|
||||
" - on_connect(): Receiving message from client.\n",
|
||||
" - Connected to server on ws://127.0.0.1:8765\n",
|
||||
" - Sending message to server.\n",
|
||||
" - on_connect(): Initiating chat with agent <autogen.agentchat.conversable_agent.ConversableAgent object at 0x75ad84a72b30> using message 'Check out the weather in Paris and write a poem about it.'\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"Check out the weather in Paris and write a poem about it.\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
|
||||
"\u001b[32m\u001b[32m\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Suggested tool Call (call_U5VR0hck9KhDFWPdvmo1Eoke): weather_forecast *****\u001b[0m\n",
|
||||
"Arguments: \n",
|
||||
"{\n",
|
||||
" \"city\": \"Paris\"\n",
|
||||
"}\n",
|
||||
"\u001b[32m*********************************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[35m\n",
|
||||
">>>>>>>> EXECUTING FUNCTION weather_forecast...\u001b[0m\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
|
||||
"\n",
|
||||
"\u001b[32m***** Response from calling tool \"call_U5VR0hck9KhDFWPdvmo1Eoke\" *****\u001b[0m\n",
|
||||
"The weather forecast for Paris is sunny.\n",
|
||||
"\u001b[32m**********************************************************************\u001b[0m\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[31m\n",
|
||||
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
|
||||
"\u001b[32m\u001b[32mIn the city of love, shines the sun above,\n",
|
||||
"Paris basks in golden rays, a beautiful day to praise.\n",
|
||||
"Strolling down the Champs Elysées, the warm light leads the way,\n",
|
||||
"In the glow, silhouettes dance, a perfect setting for romance.\n",
|
||||
"\n",
|
||||
"In the sunlight, the Seine sparkles bright, reflecting the City of Light,\n",
|
||||
"Not a cloud in the crystal-clear blue sky, as the doves sail high.\n",
|
||||
"Sunny Paris so profound, beauty all around,\n",
|
||||
"Alive under the radiant crown, she wears her sunlight like a gown.\n",
|
||||
"\n",
|
||||
"TERMINATE\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
|
||||
"\n",
|
||||
"In the city of love, shines the sun above,\n",
|
||||
"Paris basks in golden rays, a beautiful day to praise.\n",
|
||||
"Strolling down the Champs Elysées, the warm light leads the way,\n",
|
||||
"In the glow, silhouettes dance, a perfect setting for romance.\n",
|
||||
"\n",
|
||||
"In the sunlight, the Seine sparkles bright, reflecting the City of Light,\n",
|
||||
"Not a cloud in the crystal-clear blue sky, as the doves sail high.\n",
|
||||
"Sunny Paris so profound, beauty all around,\n",
|
||||
"Alive under the radiant crown, she wears her sunlight like a gown.\n",
|
||||
"\n",
|
||||
"TERMINATE\n",
|
||||
"\n",
|
||||
" - Received TERMINATE message. Exiting.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"with IOWebsockets.run_server_in_thread(on_connect=on_connect, port=8765) as uri:\n",
|
||||
" print(f\" - test_setup() with websocket server running on {uri}.\", flush=True)\n",
|
||||
"\n",
|
||||
" with ws_connect(uri) as websocket:\n",
|
||||
" print(f\" - Connected to server on {uri}\", flush=True)\n",
|
||||
"\n",
|
||||
" print(\" - Sending message to server.\", flush=True)\n",
|
||||
" # websocket.send(\"2+2=?\")\n",
|
||||
" websocket.send(\"Check out the weather in Paris and write a poem about it.\")\n",
|
||||
"\n",
|
||||
" while True:\n",
|
||||
" message = websocket.recv()\n",
|
||||
" message = message.decode(\"utf-8\") if isinstance(message, bytes) else message\n",
|
||||
"\n",
|
||||
" print(message, end=\"\", flush=True)\n",
|
||||
"\n",
|
||||
" if \"TERMINATE\" in message:\n",
|
||||
" print()\n",
|
||||
" print(\" - Received TERMINATE message. Exiting.\", flush=True)\n",
|
||||
" break"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3a656564",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Testing websockets server running inside FastAPI server with HTML/JS client\n",
|
||||
"\n",
|
||||
"The code snippets below outlines an approach for testing an `on_connect` function in a web environment using [FastAPI](https://fastapi.tiangolo.com/) to serve a simple interactive HTML page. This method allows users to send messages through a web interface, which are then processed by the server running the AutoGen framework via websockets. Here's a step-by-step explanation:\n",
|
||||
"\n",
|
||||
"1. **FastAPI Application Setup**: The code initiates by importing necessary libraries and setting up a FastAPI application. FastAPI is a modern, fast web framework for building APIs with Python 3.7+ based on standard Python type hints.\n",
|
||||
"\n",
|
||||
"2. **HTML Template for User Interaction**: An HTML template is defined as a multi-line Python string, which includes a basic form for message input and a script for managing websocket communication. This template creates a user interface where messages can be sent to the server and responses are displayed dynamically.\n",
|
||||
"\n",
|
||||
"3. **Running the Websocket Server**: The `run_websocket_server` async context manager starts the websocket server using `IOWebsockets.run_server_in_thread` with the specified `on_connect` function and port. This server listens for incoming websocket connections.\n",
|
||||
"\n",
|
||||
"4. **FastAPI Route for Serving HTML Page**: A FastAPI route (`@app.get(\"/\")`) is defined to serve the HTML page to users. When a user accesses the root URL, the HTML content for the websocket chat is returned, allowing them to interact with the websocket server.\n",
|
||||
"\n",
|
||||
"5. **Starting the FastAPI Application**: Lastly, the FastAPI application is started using Uvicorn, an ASGI server, configured with the app and additional parameters as needed. The server is then launched to serve the FastAPI application, making the interactive HTML page accessible to users.\n",
|
||||
"\n",
|
||||
"This method of testing allows for interactive communication between the user and the server, providing a practical way to demonstrate and evaluate the behavior of the on_connect function in real-time. Users can send messages through the webpage, and the server processes these messages as per the logic defined in the on_connect function, showcasing the capabilities and responsiveness of the AutoGen framework's websocket handling in a user-friendly manner."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "5e55dc06",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "ModuleNotFoundError",
|
||||
"evalue": "No module named 'fastapi'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[4], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mcontextlib\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m asynccontextmanager \u001b[38;5;66;03m# noqa: E402\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpathlib\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Path \u001b[38;5;66;03m# noqa: E402\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfastapi\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FastAPI \u001b[38;5;66;03m# noqa: E402\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfastapi\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mresponses\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m HTMLResponse \u001b[38;5;66;03m# noqa: E402\u001b[39;00m\n\u001b[1;32m 7\u001b[0m PORT \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m8000\u001b[39m\n",
|
||||
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'fastapi'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from contextlib import asynccontextmanager # noqa: E402\n",
|
||||
"from pathlib import Path # noqa: E402\n",
|
||||
"\n",
|
||||
"from fastapi import FastAPI # noqa: E402\n",
|
||||
"from fastapi.responses import HTMLResponse # noqa: E402\n",
|
||||
"\n",
|
||||
"PORT = 8000\n",
|
||||
"\n",
|
||||
"html = \"\"\"\n",
|
||||
"<!DOCTYPE html>\n",
|
||||
"<html>\n",
|
||||
" <head>\n",
|
||||
" <title>Autogen websocket test</title>\n",
|
||||
" </head>\n",
|
||||
" <body>\n",
|
||||
" <h1>WebSocket Chat</h1>\n",
|
||||
" <form action=\"\" onsubmit=\"sendMessage(event)\">\n",
|
||||
" <input type=\"text\" id=\"messageText\" autocomplete=\"off\"/>\n",
|
||||
" <button>Send</button>\n",
|
||||
" </form>\n",
|
||||
" <ul id='messages'>\n",
|
||||
" </ul>\n",
|
||||
" <script>\n",
|
||||
" var ws = new WebSocket(\"ws://localhost:8080/ws\");\n",
|
||||
" ws.onmessage = function(event) {\n",
|
||||
" var messages = document.getElementById('messages')\n",
|
||||
" var message = document.createElement('li')\n",
|
||||
" var content = document.createTextNode(event.data)\n",
|
||||
" message.appendChild(content)\n",
|
||||
" messages.appendChild(message)\n",
|
||||
" };\n",
|
||||
" function sendMessage(event) {\n",
|
||||
" var input = document.getElementById(\"messageText\")\n",
|
||||
" ws.send(input.value)\n",
|
||||
" input.value = ''\n",
|
||||
" event.preventDefault()\n",
|
||||
" }\n",
|
||||
" </script>\n",
|
||||
" </body>\n",
|
||||
"</html>\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@asynccontextmanager\n",
|
||||
"async def run_websocket_server(app):\n",
|
||||
" with IOWebsockets.run_server_in_thread(on_connect=on_connect, port=8080) as uri:\n",
|
||||
" print(f\"Websocket server started at {uri}.\", flush=True)\n",
|
||||
"\n",
|
||||
" yield\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"app = FastAPI(lifespan=run_websocket_server)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@app.get(\"/\")\n",
|
||||
"async def get():\n",
|
||||
" return HTMLResponse(html)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d92e50b5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import uvicorn # noqa: E402\n",
|
||||
"\n",
|
||||
"config = uvicorn.Config(app)\n",
|
||||
"server = uvicorn.Server(config)\n",
|
||||
"await server.serve() # noqa: F704"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1c8c9f61",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The testing setup described above, leveraging FastAPI and websockets, not only serves as a robust testing framework for the on_connect function but also lays the groundwork for developing real-world applications. This approach exemplifies how web-based interactions can be made dynamic and real-time, a critical aspect of modern application development.\n",
|
||||
"\n",
|
||||
"For instance, this setup can be directly applied or adapted to build interactive chat applications, real-time data dashboards, or live support systems. The integration of websockets enables the server to push updates to clients instantly, a key feature for applications that rely on the timely delivery of information. For example, a chat application built on this framework can support instantaneous messaging between users, enhancing user engagement and satisfaction.\n",
|
||||
"\n",
|
||||
"Moreover, the simplicity and interactivity of the HTML page used for testing reflect how user interfaces can be designed to provide seamless experiences. Developers can expand upon this foundation to incorporate more sophisticated elements such as user authentication, message encryption, and custom user interactions, further tailoring the application to meet specific use case requirements.\n",
|
||||
"\n",
|
||||
"The flexibility of the FastAPI framework, combined with the real-time communication enabled by websockets, provides a powerful toolset for developers looking to build scalable, efficient, and highly interactive web applications. Whether it's for creating collaborative platforms, streaming services, or interactive gaming experiences, this testing setup offers a glimpse into the potential applications that can be developed with these technologies."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cfb50946",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Testing websockets server with HTML/JS client\n",
|
||||
"\n",
|
||||
"The provided code snippet below is an example of how to create an interactive testing environment for an `on_connect` function using Python's built-in `http.server` module. This setup allows for real-time interaction within a web browser, enabling developers to test the websocket functionality in a more user-friendly and practical manner. Here's a breakdown of how this code operates and its potential applications:\n",
|
||||
"\n",
|
||||
"1. **Serving a Simple HTML Page**: The code starts by defining an HTML page that includes a form for sending messages and a list to display incoming messages. JavaScript is used to handle the form submission and websocket communication.\n",
|
||||
"\n",
|
||||
"2. **Temporary Directory for HTML File**: A temporary directory is created to store the HTML file. This approach ensures that the testing environment is clean and isolated, minimizing conflicts with existing files or configurations.\n",
|
||||
"\n",
|
||||
"3. **Custom HTTP Request Handler**: A custom subclass of `SimpleHTTPRequestHandler` is defined to serve the HTML file. This handler overrides the do_GET method to redirect the root path (`/`) to the `chat.html` page, ensuring that visitors to the server's root URL are immediately presented with the chat interface.\n",
|
||||
"\n",
|
||||
"4. **Starting the Websocket Server**: Concurrently, a websocket server is started on a different port using the `IOWebsockets.run_server_in_thread` method, with the previously defined `on_connect` function as the callback for new connections.\n",
|
||||
"\n",
|
||||
"5. **HTTP Server for the HTML Interface**: An HTTP server is instantiated to serve the HTML chat interface, enabling users to interact with the websocket server through a web browser.\n",
|
||||
"\n",
|
||||
"This setup showcases a practical application of integrating websockets with a simple HTTP server to create a dynamic and interactive web application. By using Python's standard library modules, it demonstrates a low-barrier entry to developing real-time applications such as chat systems, live notifications, or interactive dashboards.\n",
|
||||
"\n",
|
||||
"The key takeaway from this code example is how easily Python's built-in libraries can be leveraged to prototype and test complex web functionalities. For developers looking to build real-world applications, this approach offers a straightforward method to validate and refine websocket communication logic before integrating it into larger frameworks or systems. The simplicity and accessibility of this testing setup make it an excellent starting point for developing a wide range of interactive web applications.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "708a98de",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from http.server import HTTPServer, SimpleHTTPRequestHandler # noqa: E402\n",
|
||||
"\n",
|
||||
"PORT = 8000\n",
|
||||
"\n",
|
||||
"html = \"\"\"\n",
|
||||
"<!DOCTYPE html>\n",
|
||||
"<html>\n",
|
||||
" <head>\n",
|
||||
" <title>Autogen websocket test</title>\n",
|
||||
" </head>\n",
|
||||
" <body>\n",
|
||||
" <h1>WebSocket Chat</h1>\n",
|
||||
" <form action=\"\" onsubmit=\"sendMessage(event)\">\n",
|
||||
" <input type=\"text\" id=\"messageText\" autocomplete=\"off\"/>\n",
|
||||
" <button>Send</button>\n",
|
||||
" </form>\n",
|
||||
" <ul id='messages'>\n",
|
||||
" </ul>\n",
|
||||
" <script>\n",
|
||||
" var ws = new WebSocket(\"ws://localhost:8080/ws\");\n",
|
||||
" ws.onmessage = function(event) {\n",
|
||||
" var messages = document.getElementById('messages')\n",
|
||||
" var message = document.createElement('li')\n",
|
||||
" var content = document.createTextNode(event.data)\n",
|
||||
" message.appendChild(content)\n",
|
||||
" messages.appendChild(message)\n",
|
||||
" };\n",
|
||||
" function sendMessage(event) {\n",
|
||||
" var input = document.getElementById(\"messageText\")\n",
|
||||
" ws.send(input.value)\n",
|
||||
" input.value = ''\n",
|
||||
" event.preventDefault()\n",
|
||||
" }\n",
|
||||
" </script>\n",
|
||||
" </body>\n",
|
||||
"</html>\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"with TemporaryDirectory() as temp_dir:\n",
|
||||
" # create a simple HTTP webpage\n",
|
||||
" path = Path(temp_dir) / \"chat.html\"\n",
|
||||
" with open(path, \"w\") as f:\n",
|
||||
" f.write(html)\n",
|
||||
"\n",
|
||||
" #\n",
|
||||
" class MyRequestHandler(SimpleHTTPRequestHandler):\n",
|
||||
" def __init__(self, *args, **kwargs):\n",
|
||||
" super().__init__(*args, directory=temp_dir, **kwargs)\n",
|
||||
"\n",
|
||||
" def do_GET(self):\n",
|
||||
" if self.path == \"/\":\n",
|
||||
" self.path = \"/chat.html\"\n",
|
||||
" return SimpleHTTPRequestHandler.do_GET(self)\n",
|
||||
"\n",
|
||||
" handler = MyRequestHandler\n",
|
||||
"\n",
|
||||
" with IOWebsockets.run_server_in_thread(on_connect=on_connect, port=8080) as uri:\n",
|
||||
" print(f\"Websocket server started at {uri}.\", flush=True)\n",
|
||||
"\n",
|
||||
" with HTTPServer((\"\", PORT), handler) as httpd:\n",
|
||||
" print(\"HTTP server started at http://localhost:\" + str(PORT))\n",
|
||||
" httpd.serve_forever()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "19656d0e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"front_matter": {
|
||||
"description": "Websockets facilitate real-time, bidirectional communication between web clients and servers, enhancing the responsiveness and interactivity of AutoGen-powered applications.",
|
||||
"tags": [
|
||||
"websockets",
|
||||
"streaming"
|
||||
]
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "flaml_dev",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@ -67,6 +67,8 @@ files = [
|
||||
"autogen/exception_utils.py",
|
||||
"autogen/coding",
|
||||
"autogen/oai/openai_utils.py",
|
||||
"autogen/io",
|
||||
"test/io",
|
||||
]
|
||||
|
||||
exclude = [
|
||||
|
||||
3
setup.py
3
setup.py
@ -65,8 +65,9 @@ setuptools.setup(
|
||||
"graph": ["networkx", "matplotlib"],
|
||||
"websurfer": ["beautifulsoup4", "markdownify", "pdfminer.six", "pathvalidate"],
|
||||
"redis": ["redis"],
|
||||
"websockets": ["websockets>=12.0,<13"],
|
||||
"jupyter-executor": jupyter_executor,
|
||||
"types": ["mypy==1.9.0"] + jupyter_executor,
|
||||
"types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor,
|
||||
},
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
|
||||
@ -157,7 +157,8 @@ def test_two_agents_logging(db_connection):
|
||||
assert row["session_id"] and row["session_id"] == session_id
|
||||
assert row["class"] in ["AzureOpenAI", "OpenAI"]
|
||||
init_args = json.loads(row["init_args"])
|
||||
assert "api_version" in init_args
|
||||
if row["class"] == "AzureOpenAI":
|
||||
assert "api_version" in init_args
|
||||
assert row["timestamp"], "timestamp is empty"
|
||||
|
||||
# Verify oai wrapper table
|
||||
|
||||
28
test/io/test_base.py
Normal file
28
test/io/test_base.py
Normal file
@ -0,0 +1,28 @@
|
||||
from typing import Any
|
||||
|
||||
from autogen.io import IOConsole, IOStream, IOWebsockets
|
||||
|
||||
|
||||
class TestIOStream:
|
||||
def test_initial_default_io_stream(self) -> None:
|
||||
assert isinstance(IOStream.get_default(), IOConsole)
|
||||
|
||||
def test_set_default_io_stream(self) -> None:
|
||||
class MyIOStream(IOStream):
|
||||
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
|
||||
pass
|
||||
|
||||
def input(self, prompt: str = "", *, password: bool = False) -> str:
|
||||
return "Hello, World!"
|
||||
|
||||
assert isinstance(IOStream.get_default(), IOConsole)
|
||||
|
||||
with IOStream.set_default(MyIOStream()):
|
||||
assert isinstance(IOStream.get_default(), MyIOStream)
|
||||
|
||||
with IOStream.set_default(IOConsole()):
|
||||
assert isinstance(IOStream.get_default(), IOConsole)
|
||||
|
||||
assert isinstance(IOStream.get_default(), MyIOStream)
|
||||
|
||||
assert isinstance(IOStream.get_default(), IOConsole)
|
||||
38
test/io/test_console.py
Normal file
38
test/io/test_console.py
Normal file
@ -0,0 +1,38 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from autogen.io import IOConsole
|
||||
|
||||
|
||||
class TestConsoleIO:
|
||||
def setup_method(self) -> None:
|
||||
self.console_io = IOConsole()
|
||||
|
||||
@patch("builtins.print")
|
||||
def test_print(self, mock_print: MagicMock) -> None:
|
||||
# calling the print method should call the mock of the builtin print function
|
||||
self.console_io.print("Hello, World!", flush=True)
|
||||
mock_print.assert_called_once_with("Hello, World!", end="\n", sep=" ", flush=True)
|
||||
|
||||
@patch("builtins.input")
|
||||
def test_input(self, mock_input: MagicMock) -> None:
|
||||
# calling the input method should call the mock of the builtin input function
|
||||
mock_input.return_value = "Hello, World!"
|
||||
|
||||
actual = self.console_io.input("Hi!")
|
||||
assert actual == "Hello, World!"
|
||||
mock_input.assert_called_once_with("Hi!")
|
||||
|
||||
@pytest.mark.parametrize("prompt", ["", "Password: ", "Enter you password:"])
|
||||
def test_input_password(self, monkeypatch: pytest.MonkeyPatch, prompt: str) -> None:
|
||||
mock_getpass = MagicMock()
|
||||
mock_getpass.return_value = "123456"
|
||||
monkeypatch.setattr("getpass.getpass", mock_getpass)
|
||||
|
||||
actual = self.console_io.input(prompt, password=True)
|
||||
assert actual == "123456"
|
||||
if prompt == "":
|
||||
mock_getpass.assert_called_once_with("Password: ")
|
||||
else:
|
||||
mock_getpass.assert_called_once_with(prompt)
|
||||
176
test/io/test_websockets.py
Normal file
176
test/io/test_websockets.py
Normal file
@ -0,0 +1,176 @@
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
from autogen.io.base import IOStream
|
||||
from conftest import skip_openai
|
||||
|
||||
import autogen
|
||||
from autogen.cache.cache import Cache
|
||||
from autogen.io import IOWebsockets
|
||||
|
||||
KEY_LOC = "notebook"
|
||||
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
|
||||
|
||||
# Check if the websockets module is available
|
||||
try:
|
||||
from websockets.sync.client import connect as ws_connect
|
||||
except ImportError: # pragma: no cover
|
||||
skip_test = True
|
||||
else:
|
||||
skip_test = False
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip_test, reason="websockets module is not available")
|
||||
class TestConsoleIOWithWebsockets:
|
||||
def test_input_print(self) -> None:
|
||||
print()
|
||||
print("Testing input/print", flush=True)
|
||||
|
||||
def on_connect(iostream: IOWebsockets) -> None:
|
||||
print(f" - on_connect(): Connected to client using IOWebsockets {iostream}", flush=True)
|
||||
|
||||
print(" - on_connect(): Receiving message from client.", flush=True)
|
||||
|
||||
msg = iostream.input()
|
||||
|
||||
print(f" - on_connect(): Received message '{msg}' from client.", flush=True)
|
||||
|
||||
assert msg == "Hello world!"
|
||||
|
||||
for msg in ["Hello, World!", "Over and out!"]:
|
||||
print(f" - on_connect(): Sending message '{msg}' to client.", flush=True)
|
||||
|
||||
iostream.print(msg)
|
||||
|
||||
print(" - on_connect(): Receiving message from client.", flush=True)
|
||||
|
||||
msg = iostream.input("May I?")
|
||||
|
||||
print(f" - on_connect(): Received message '{msg}' from client.", flush=True)
|
||||
assert msg == "Yes"
|
||||
|
||||
return
|
||||
|
||||
with IOWebsockets.run_server_in_thread(on_connect=on_connect, port=8765) as uri:
|
||||
print(f" - test_setup() with websocket server running on {uri}.", flush=True)
|
||||
|
||||
with ws_connect(uri) as websocket:
|
||||
print(f" - Connected to server on {uri}", flush=True)
|
||||
|
||||
print(" - Sending message to server.", flush=True)
|
||||
websocket.send("Hello world!")
|
||||
|
||||
for expected in ["Hello, World!", "Over and out!", "May I?"]:
|
||||
print(" - Receiving message from server.", flush=True)
|
||||
message = websocket.recv()
|
||||
message = message.decode("utf-8") if isinstance(message, bytes) else message
|
||||
# drop the newline character
|
||||
if message.endswith("\n"):
|
||||
message = message[:-1]
|
||||
|
||||
print(
|
||||
f" - Asserting received message '{message}' is the same as the expected message '{expected}'",
|
||||
flush=True,
|
||||
)
|
||||
assert message == expected
|
||||
|
||||
print(" - Sending message 'Yes' to server.", flush=True)
|
||||
websocket.send("Yes")
|
||||
|
||||
print("Test passed.", flush=True)
|
||||
|
||||
@pytest.mark.skipif(skip_openai, reason="requested to skip")
|
||||
def test_chat(self) -> None:
|
||||
print("Testing setup", flush=True)
|
||||
|
||||
success_dict = {"success": False}
|
||||
|
||||
def on_connect(iostream: IOWebsockets, success_dict: Dict[str, bool] = success_dict) -> None:
|
||||
print(f" - on_connect(): Connected to client using IOWebsockets {iostream}", flush=True)
|
||||
|
||||
print(" - on_connect(): Receiving message from client.", flush=True)
|
||||
|
||||
initial_msg = iostream.input()
|
||||
|
||||
config_list = autogen.config_list_from_json(
|
||||
OAI_CONFIG_LIST,
|
||||
filter_dict={
|
||||
"model": [
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-4",
|
||||
"gpt-4-0314",
|
||||
"gpt4",
|
||||
"gpt-4-32k",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-32k-v0314",
|
||||
],
|
||||
},
|
||||
file_location=KEY_LOC,
|
||||
)
|
||||
|
||||
llm_config = {
|
||||
"config_list": config_list,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
agent = autogen.ConversableAgent(
|
||||
name="chatbot",
|
||||
system_message="Complete a task given to you and reply TERMINATE when the task is done.",
|
||||
llm_config=llm_config,
|
||||
)
|
||||
|
||||
# create a UserProxyAgent instance named "user_proxy"
|
||||
user_proxy = autogen.UserProxyAgent(
|
||||
name="user_proxy",
|
||||
system_message="A proxy for the user.",
|
||||
is_termination_msg=lambda x: x.get("content", "")
|
||||
and x.get("content", "").rstrip().endswith("TERMINATE"),
|
||||
human_input_mode="NEVER",
|
||||
max_consecutive_auto_reply=10,
|
||||
)
|
||||
|
||||
# we will use a temporary directory as the cache path root to ensure fresh completion each time
|
||||
with TemporaryDirectory() as cache_path_root:
|
||||
with Cache.disk(cache_path_root=cache_path_root) as cache:
|
||||
print(
|
||||
f" - on_connect(): Initiating chat with agent {agent} using message '{initial_msg}'",
|
||||
flush=True,
|
||||
)
|
||||
user_proxy.initiate_chat( # noqa: F704
|
||||
agent,
|
||||
message=initial_msg,
|
||||
cache=cache,
|
||||
)
|
||||
|
||||
success_dict["success"] = True
|
||||
|
||||
return
|
||||
|
||||
with IOWebsockets.run_server_in_thread(on_connect=on_connect, port=8765) as uri:
|
||||
print(f" - test_setup() with websocket server running on {uri}.", flush=True)
|
||||
|
||||
with ws_connect(uri) as websocket:
|
||||
print(f" - Connected to server on {uri}", flush=True)
|
||||
|
||||
print(" - Sending message to server.", flush=True)
|
||||
# websocket.send("2+2=?")
|
||||
websocket.send("Please write a poem about spring in a city of your choice.")
|
||||
|
||||
while True:
|
||||
message = websocket.recv()
|
||||
message = message.decode("utf-8") if isinstance(message, bytes) else message
|
||||
# drop the newline character
|
||||
if message.endswith("\n"):
|
||||
message = message[:-1]
|
||||
|
||||
print(message, end="", flush=True)
|
||||
|
||||
if "TERMINATE" in message:
|
||||
print()
|
||||
print(" - Received TERMINATE message. Exiting.", flush=True)
|
||||
break
|
||||
|
||||
assert success_dict["success"]
|
||||
print("Test passed.", flush=True)
|
||||
Loading…
x
Reference in New Issue
Block a user