Introducing IOStream protocol and adding support for websockets (#1551)

* Introducing IOStream

* bug fixing

* polishing

* refactoring

* refactoring

* refactoring

* wip: async tests

* websockets added

* wip

* merge with main

* notebook added

* FastAPI example added

* wip

* merge

* getter/setter to iostream added

* website/blog/2024-03-03-AutoGen-Update/img/dalle_gpt4v.png: convert to Git LFS

* website/blog/2024-03-03-AutoGen-Update/img/gaia.png: convert to Git LFS

* website/blog/2024-03-03-AutoGen-Update/img/teach.png: convert to Git LFS

* add SSL support

* wip

* wip

* exception handling added to on_connect()

* refactoring: default iostream is being set in a context manager

* test fix

* polishing

* polishing

* polishing

* fixed bug with new thread

* polishing

* a bit of refactoring and docs added

* notebook added to docs

* type checking added to CI

* CI fix

* CI fix

* CI fix

* polishing

* obsolete todo comment removed

* fixed precommit error

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Davor Runje 2024-03-26 23:39:55 +01:00 committed by GitHub
parent 72994ea127
commit 78aa0eb220
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 1252 additions and 66 deletions

View File

@ -65,7 +65,7 @@ jobs:
- name: Coverage
if: matrix.python-version == '3.10'
run: |
pip install -e .[test,redis]
pip install -e .[test,redis,websockets]
coverage run -a -m pytest test --ignore=test/agentchat/contrib --skip-openai --durations=10 --durations-min=1.0
coverage xml
- name: Upload coverage to Codecov

View File

@ -7,6 +7,7 @@ from dataclasses import dataclass
from .utils import consolidate_chat_info
import datetime
import warnings
from ..io.base import IOStream
from ..formatting_utils import colored
@ -103,6 +104,8 @@ def __find_async_chat_order(chat_ids: Set[int], prerequisites: List[Prerequisite
def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
iostream = IOStream.get_default()
if "message" not in chat_info:
warnings.warn(
"message is not provided in a chat_queue entry. input() will be called to get the initial message.",
@ -122,8 +125,8 @@ def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
print_message = "Dict: " + str(message)
elif message is None:
print_message = "None"
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
print(
iostream.print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
iostream.print(
colored(
"Starting a new chat....",
"blue",
@ -131,9 +134,9 @@ def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
flush=True,
)
if chat_info.get("verbose", False):
print(colored("Message:\n" + print_message, "blue"), flush=True)
print(colored("Carryover:\n" + print_carryover, "blue"), flush=True)
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
iostream.print(colored("Message:\n" + print_message, "blue"), flush=True)
iostream.print(colored("Carryover:\n" + print_carryover, "blue"), flush=True)
iostream.print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:

View File

@ -32,6 +32,7 @@ from ..function_utils import get_function_schema, load_basemodels_if_needed, ser
from ..oai.client import ModelClient, OpenAIWrapper
from ..runtime_logging import log_new_agent, logging_enabled
from .agent import Agent, LLMAgent
from ..io.base import IOStream
from .chat import ChatResult, a_initiate_chats, initiate_chats
from .utils import consolidate_chat_info, gather_usage_summary
@ -681,8 +682,9 @@ class ConversableAgent(LLMAgent):
)
def _print_received_message(self, message: Union[Dict, str], sender: Agent):
iostream = IOStream.get_default()
# print the message received
print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True)
iostream.print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True)
message = self._message_to_dict(message)
if message.get("tool_responses"): # Handle tool multi-call responses
@ -698,9 +700,9 @@ class ConversableAgent(LLMAgent):
id_key = "tool_call_id"
id = message.get(id_key, "No id found")
func_print = f"***** Response from calling {message['role']} ({id}) *****"
print(colored(func_print, "green"), flush=True)
print(message["content"], flush=True)
print(colored("*" * len(func_print), "green"), flush=True)
iostream.print(colored(func_print, "green"), flush=True)
iostream.print(message["content"], flush=True)
iostream.print(colored("*" * len(func_print), "green"), flush=True)
else:
content = message.get("content")
if content is not None:
@ -710,35 +712,35 @@ class ConversableAgent(LLMAgent):
message["context"],
self.llm_config and self.llm_config.get("allow_format_str_template", False),
)
print(content_str(content), flush=True)
iostream.print(content_str(content), flush=True)
if "function_call" in message and message["function_call"]:
function_call = dict(message["function_call"])
func_print = (
f"***** Suggested function call: {function_call.get('name', '(No function name found)')} *****"
)
print(colored(func_print, "green"), flush=True)
print(
iostream.print(colored(func_print, "green"), flush=True)
iostream.print(
"Arguments: \n",
function_call.get("arguments", "(No arguments found)"),
flush=True,
sep="",
)
print(colored("*" * len(func_print), "green"), flush=True)
iostream.print(colored("*" * len(func_print), "green"), flush=True)
if "tool_calls" in message and message["tool_calls"]:
for tool_call in message["tool_calls"]:
id = tool_call.get("id", "No tool call id found")
function_call = dict(tool_call.get("function", {}))
func_print = f"***** Suggested tool call ({id}): {function_call.get('name', '(No function name found)')} *****"
print(colored(func_print, "green"), flush=True)
print(
iostream.print(colored(func_print, "green"), flush=True)
iostream.print(
"Arguments: \n",
function_call.get("arguments", "(No arguments found)"),
flush=True,
sep="",
)
print(colored("*" * len(func_print), "green"), flush=True)
iostream.print(colored("*" * len(func_print), "green"), flush=True)
print("\n", "-" * 80, flush=True, sep="")
iostream.print("\n", "-" * 80, flush=True, sep="")
def _process_received_message(self, message: Union[Dict, str], sender: Agent, silent: bool):
# When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.)
@ -1229,6 +1231,7 @@ class ConversableAgent(LLMAgent):
recipient: the agent with whom the chat history to clear. If None, clear the chat history with all agents.
nr_messages_to_preserve: the number of newest messages to preserve in the chat history.
"""
iostream = IOStream.get_default()
if recipient is None:
if nr_messages_to_preserve:
for key in self._oai_messages:
@ -1238,7 +1241,7 @@ class ConversableAgent(LLMAgent):
first_msg_to_save = self._oai_messages[key][-nr_messages_to_preserve_internal]
if "tool_responses" in first_msg_to_save:
nr_messages_to_preserve_internal += 1
print(
iostream.print(
f"Preserving one more message for {self.name} to not divide history between tool call and "
f"tool response."
)
@ -1249,7 +1252,7 @@ class ConversableAgent(LLMAgent):
else:
self._oai_messages[recipient].clear()
if nr_messages_to_preserve:
print(
iostream.print(
colored(
"WARNING: `nr_preserved_messages` is ignored when clearing chat history with a specific agent.",
"yellow",
@ -1323,8 +1326,19 @@ class ConversableAgent(LLMAgent):
config: Optional[Any] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
"""Generate a reply using autogen.oai asynchronously."""
iostream = IOStream.get_default()
def _generate_oai_reply(
self, iostream: IOStream, *args: Any, **kwargs: Any
) -> Tuple[bool, Union[str, Dict, None]]:
with IOStream.set_default(iostream):
return self.generate_oai_reply(*args, **kwargs)
return await asyncio.get_event_loop().run_in_executor(
None, functools.partial(self.generate_oai_reply, messages=messages, sender=sender, config=config)
None,
functools.partial(
_generate_oai_reply, self=self, iostream=iostream, messages=messages, sender=sender, config=config
),
)
def _generate_code_execution_reply_using_executor(
@ -1334,6 +1348,8 @@ class ConversableAgent(LLMAgent):
config: Optional[Union[Dict, Literal[False]]] = None,
):
"""Generate a reply using code executor."""
iostream = IOStream.get_default()
if config is not None:
raise ValueError("config is not supported for _generate_code_execution_reply_using_executor.")
if self._code_execution_config is False:
@ -1371,7 +1387,7 @@ class ConversableAgent(LLMAgent):
num_code_blocks = len(code_blocks)
if num_code_blocks == 1:
print(
iostream.print(
colored(
f"\n>>>>>>>> EXECUTING CODE BLOCK (inferred language is {code_blocks[0].language})...",
"red",
@ -1379,7 +1395,7 @@ class ConversableAgent(LLMAgent):
flush=True,
)
else:
print(
iostream.print(
colored(
f"\n>>>>>>>> EXECUTING {num_code_blocks} CODE BLOCKS (inferred languages are [{', '.join([x.language for x in code_blocks])}])...",
"red",
@ -1631,6 +1647,8 @@ class ConversableAgent(LLMAgent):
- Tuple[bool, Union[str, Dict, None]]: A tuple containing a boolean indicating if the conversation
should be terminated, and a human reply which can be a string, a dictionary, or None.
"""
iostream = IOStream.get_default()
if config is None:
config = self
if messages is None:
@ -1675,7 +1693,7 @@ class ConversableAgent(LLMAgent):
# print the no_human_input_msg
if no_human_input_msg:
print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True)
iostream.print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True)
# stop the conversation
if reply == "exit":
@ -1715,7 +1733,7 @@ class ConversableAgent(LLMAgent):
# increment the consecutive_auto_reply_counter
self._consecutive_auto_reply_counter[sender] += 1
if self.human_input_mode != "NEVER":
print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True)
iostream.print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True)
return False, None
@ -1742,6 +1760,8 @@ class ConversableAgent(LLMAgent):
- Tuple[bool, Union[str, Dict, None]]: A tuple containing a boolean indicating if the conversation
should be terminated, and a human reply which can be a string, a dictionary, or None.
"""
iostream = IOStream.get_default()
if config is None:
config = self
if messages is None:
@ -1786,7 +1806,7 @@ class ConversableAgent(LLMAgent):
# print the no_human_input_msg
if no_human_input_msg:
print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True)
iostream.print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True)
# stop the conversation
if reply == "exit":
@ -1826,7 +1846,7 @@ class ConversableAgent(LLMAgent):
# increment the consecutive_auto_reply_counter
self._consecutive_auto_reply_counter[sender] += 1
if self.human_input_mode != "NEVER":
print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True)
iostream.print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True)
return False, None
@ -2001,7 +2021,9 @@ class ConversableAgent(LLMAgent):
Returns:
str: human input.
"""
reply = input(prompt)
iostream = IOStream.get_default()
reply = iostream.input(prompt)
self._human_input.append(reply)
return reply
@ -2016,8 +2038,8 @@ class ConversableAgent(LLMAgent):
Returns:
str: human input.
"""
reply = input(prompt)
self._human_input.append(reply)
loop = asyncio.get_running_loop()
reply = await loop.run_in_executor(None, functools.partial(self.get_human_input, prompt))
return reply
def run_code(self, code, **kwargs):
@ -2038,12 +2060,14 @@ class ConversableAgent(LLMAgent):
def execute_code_blocks(self, code_blocks):
"""Execute the code blocks and return the result."""
iostream = IOStream.get_default()
logs_all = ""
for i, code_block in enumerate(code_blocks):
lang, code = code_block
if not lang:
lang = infer_lang(code)
print(
iostream.print(
colored(
f"\n>>>>>>>> EXECUTING CODE BLOCK {i} (inferred language is {lang})...",
"red",
@ -2124,6 +2148,8 @@ class ConversableAgent(LLMAgent):
"function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0)
See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call
"""
iostream = IOStream.get_default()
func_name = func_call.get("name", "")
func = self._function_map.get(func_name, None)
@ -2139,7 +2165,7 @@ class ConversableAgent(LLMAgent):
# Try to execute the function
if arguments is not None:
print(
iostream.print(
colored(f"\n>>>>>>>> EXECUTING FUNCTION {func_name}...", "magenta"),
flush=True,
)
@ -2152,7 +2178,7 @@ class ConversableAgent(LLMAgent):
content = f"Error: Function {func_name} not found."
if verbose:
print(
iostream.print(
colored(f"\nInput arguments: {arguments}\nOutput:\n{content}", "magenta"),
flush=True,
)
@ -2179,6 +2205,8 @@ class ConversableAgent(LLMAgent):
"function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0)
See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call
"""
iostream = IOStream.get_default()
func_name = func_call.get("name", "")
func = self._function_map.get(func_name, None)
@ -2194,7 +2222,7 @@ class ConversableAgent(LLMAgent):
# Try to execute the function
if arguments is not None:
print(
iostream.print(
colored(f"\n>>>>>>>> EXECUTING ASYNC FUNCTION {func_name}...", "magenta"),
flush=True,
)
@ -2639,10 +2667,12 @@ class ConversableAgent(LLMAgent):
def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
"""Print the usage summary."""
iostream = IOStream.get_default()
if self.client is None:
print(f"No cost incurred from agent '{self.name}'.")
iostream.print(f"No cost incurred from agent '{self.name}'.")
else:
print(f"Agent '{self.name}':")
iostream.print(f"Agent '{self.name}':")
self.client.print_usage_summary(mode)
def get_actual_usage(self) -> Union[None, Dict[str, int]]:

View File

@ -5,9 +5,9 @@ import sys
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
from autogen.agentchat.agent import Agent
from autogen.agentchat.conversable_agent import ConversableAgent
from .agent import Agent
from .conversable_agent import ConversableAgent
from ..io.base import IOStream
from ..code_utils import content_str
from ..exception_utils import AgentNameConflict, NoEligibleSpeaker, UndefinedNextAgent
from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed
@ -257,22 +257,26 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
def manual_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[Agent, None]:
"""Manually select the next speaker."""
iostream = IOStream.get_default()
if agents is None:
agents = self.agents
print("Please select the next speaker from the following list:")
iostream.print("Please select the next speaker from the following list:")
_n_agents = len(agents)
for i in range(_n_agents):
print(f"{i+1}: {agents[i].name}")
iostream.print(f"{i+1}: {agents[i].name}")
try_count = 0
# Assume the user will enter a valid number within 3 tries, otherwise use auto selection to avoid blocking.
while try_count <= 3:
try_count += 1
if try_count >= 3:
print(f"You have tried {try_count} times. The next speaker will be selected automatically.")
iostream.print(f"You have tried {try_count} times. The next speaker will be selected automatically.")
break
try:
i = input("Enter the number of the next speaker (enter nothing or `q` to use auto selection): ")
i = iostream.input(
"Enter the number of the next speaker (enter nothing or `q` to use auto selection): "
)
if i == "" or i == "q":
break
i = int(i)
@ -281,7 +285,7 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
else:
raise ValueError
except ValueError:
print(f"Invalid input. Please enter a number between 1 and {_n_agents}.")
iostream.print(f"Invalid input. Please enter a number between 1 and {_n_agents}.")
return None
def random_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[Agent, None]:
@ -740,6 +744,8 @@ class GroupChatManager(ConversableAgent):
reply (dict): reply message dict to analyze.
groupchat (GroupChat): GroupChat object.
"""
iostream = IOStream.get_default()
reply_content = reply["content"]
# Split the reply into words
words = reply_content.split()
@ -775,21 +781,21 @@ class GroupChatManager(ConversableAgent):
# clear history
if agent_to_memory_clear:
if nr_messages_to_preserve:
print(
iostream.print(
f"Clearing history for {agent_to_memory_clear.name} except last {nr_messages_to_preserve} messages."
)
else:
print(f"Clearing history for {agent_to_memory_clear.name}.")
iostream.print(f"Clearing history for {agent_to_memory_clear.name}.")
agent_to_memory_clear.clear_history(nr_messages_to_preserve=nr_messages_to_preserve)
else:
if nr_messages_to_preserve:
print(f"Clearing history for all agents except last {nr_messages_to_preserve} messages.")
iostream.print(f"Clearing history for all agents except last {nr_messages_to_preserve} messages.")
# clearing history for groupchat here
temp = groupchat.messages[-nr_messages_to_preserve:]
groupchat.messages.clear()
groupchat.messages.extend(temp)
else:
print("Clearing history for all agents.")
iostream.print("Clearing history for all agents.")
# clearing history for groupchat here
groupchat.messages.clear()
# clearing history for agents

View File

@ -30,7 +30,7 @@ class Cache:
ALLOWED_CONFIG_KEYS = ["cache_seed", "redis_url", "cache_path_root"]
@staticmethod
def redis(cache_seed: Union[str, int] = 42, redis_url: str = "redis://localhost:6379/0") -> Cache:
def redis(cache_seed: Union[str, int] = 42, redis_url: str = "redis://localhost:6379/0") -> "Cache":
"""
Create a Redis cache instance.
@ -44,7 +44,7 @@ class Cache:
return Cache({"cache_seed": cache_seed, "redis_url": redis_url})
@staticmethod
def disk(cache_seed: Union[str, int] = 42, cache_path_root: str = ".cache") -> Cache:
def disk(cache_seed: Union[str, int] = 42, cache_path_root: str = ".cache") -> "Cache":
"""
Create a Disk cache instance.
@ -81,7 +81,7 @@ class Cache:
self.config.get("cache_path_root", None),
)
def __enter__(self) -> AbstractCache:
def __enter__(self) -> "Cache":
"""
Enter the runtime context related to the cache object.

8
autogen/io/__init__.py Normal file
View File

@ -0,0 +1,8 @@
from .base import InputStream, IOStream, OutputStream
from .console import IOConsole
from .websockets import IOWebsockets
# Set the default input/output stream to the console
IOStream._default_io_stream.set(IOConsole())
__all__ = ("IOConsole", "IOStream", "InputStream", "OutputStream", "IOWebsockets")

73
autogen/io/base.py Normal file
View File

@ -0,0 +1,73 @@
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Iterator, Optional, Protocol, runtime_checkable
__all__ = ("OutputStream", "InputStream", "IOStream")
@runtime_checkable
class OutputStream(Protocol):
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
"""Print data to the output stream.
Args:
objects (any): The data to print.
sep (str, optional): The separator between objects. Defaults to " ".
end (str, optional): The end of the output. Defaults to "\n".
flush (bool, optional): Whether to flush the output. Defaults to False.
"""
... # pragma: no cover
@runtime_checkable
class InputStream(Protocol):
def input(self, prompt: str = "", *, password: bool = False) -> str:
"""Read a line from the input stream.
Args:
prompt (str, optional): The prompt to display. Defaults to "".
password (bool, optional): Whether to read a password. Defaults to False.
Returns:
str: The line read from the input stream.
"""
... # pragma: no cover
@runtime_checkable
class IOStream(InputStream, OutputStream, Protocol):
"""A protocol for input/output streams."""
@staticmethod
def get_default() -> "IOStream":
"""Get the default input/output stream.
Returns:
IOStream: The default input/output stream.
"""
iostream = IOStream._default_io_stream.get()
if iostream is None:
raise RuntimeError("No default IOStream has been set")
return iostream
# ContextVar must be used in multithreaded or async environments
_default_io_stream: ContextVar[Optional["IOStream"]] = ContextVar("default_iostream")
_default_io_stream.set(None)
@staticmethod
@contextmanager
def set_default(stream: Optional["IOStream"]) -> Iterator[None]:
"""Set the default input/output stream.
Args:
stream (IOStream): The input/output stream to set as the default.
"""
global _default_io_stream
try:
token = IOStream._default_io_stream.set(stream)
yield
finally:
IOStream._default_io_stream.reset(token)
return

37
autogen/io/console.py Normal file
View File

@ -0,0 +1,37 @@
import getpass
from typing import Any
from .base import IOStream
__all__ = ("IOConsole",)
class IOConsole(IOStream):
"""A console input/output stream."""
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
"""Print data to the output stream.
Args:
objects (any): The data to print.
sep (str, optional): The separator between objects. Defaults to " ".
end (str, optional): The end of the output. Defaults to "\n".
flush (bool, optional): Whether to flush the output. Defaults to False.
"""
print(*objects, sep=sep, end=end, flush=flush)
def input(self, prompt: str = "", *, password: bool = False) -> str:
"""Read a line from the input stream.
Args:
prompt (str, optional): The prompt to display. Defaults to "".
password (bool, optional): Whether to read a password. Defaults to False.
Returns:
str: The line read from the input stream.
"""
if password:
return getpass.getpass(prompt if prompt != "" else "Password: ")
return input(prompt)

207
autogen/io/websockets.py Normal file
View File

@ -0,0 +1,207 @@
import logging
import ssl
import threading
from contextlib import contextmanager
from functools import partial
from time import sleep
from typing import Any, Callable, Dict, Iterable, Iterator, Optional, TYPE_CHECKING, Protocol, Union
from .base import IOStream
# Check if the websockets module is available
try:
from websockets.sync.server import serve as ws_serve
except ImportError as e:
_import_error: Optional[ImportError] = e
else:
_import_error = None
__all__ = ("IOWebsockets",)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# The following type and protocols are used to define the ServerConnection and WebSocketServer classes
# if websockets is not installed, they would be untyped
Data = Union[str, bytes]
class ServerConnection(Protocol):
def send(self, message: Union[Data, Iterable[Data]]) -> None:
"""Send a message to the client.
Args:
message (Union[Data, Iterable[Data]]): The message to send.
"""
... # pragma: no cover
def recv(self, timeout: Optional[float] = None) -> Data:
"""Receive a message from the client.
Args:
timeout (Optional[float], optional): The timeout for the receive operation. Defaults to None.
Returns:
Data: The message received from the client.
"""
... # pragma: no cover
def close(self) -> None:
"""Close the connection."""
...
class WebSocketServer(Protocol):
def serve_forever(self) -> None:
"""Run the server forever."""
... # pragma: no cover
def shutdown(self) -> None:
"""Shutdown the server."""
... # pragma: no cover
def __enter__(self) -> "WebSocketServer":
"""Enter the server context."""
... # pragma: no cover
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
"""Exit the server context."""
... # pragma: no cover
class IOWebsockets(IOStream):
"""A websocket input/output stream."""
def __init__(self, websocket: ServerConnection) -> None:
"""Initialize the websocket input/output stream.
Args:
websocket (ServerConnection): The websocket server.
Raises:
ImportError: If the websockets module is not available.
"""
if _import_error is not None:
raise _import_error # pragma: no cover
self._websocket = websocket
@staticmethod
def _handler(websocket: ServerConnection, on_connect: Callable[["IOWebsockets"], None]) -> None:
"""The handler function for the websocket server."""
logger.info(f" - IOWebsockets._handler(): Client connected on {websocket}")
# create a new IOWebsockets instance using the websocket that is create when a client connects
try:
iowebsocket = IOWebsockets(websocket)
with IOStream.set_default(iowebsocket):
# call the on_connect function
try:
on_connect(iowebsocket)
except Exception as e:
logger.warning(f" - IOWebsockets._handler(): Error in on_connect: {e}")
except Exception as e:
logger.error(f" - IOWebsockets._handler(): Unexpected error in IOWebsockets: {e}")
@staticmethod
@contextmanager
def run_server_in_thread(
*,
host: str = "127.0.0.1",
port: int = 8765,
on_connect: Callable[["IOWebsockets"], None],
ssl_context: Optional[ssl.SSLContext] = None,
**kwargs: Any,
) -> Iterator[str]:
"""Factory function to create a websocket input/output stream.
Args:
host (str, optional): The host to bind the server to. Defaults to "127.0.0.1".
port (int, optional): The port to bind the server to. Defaults to 8765.
on_connect (Callable[[IOWebsockets], None]): The function to be executed on client connection. Typically creates agents and initiate chat.
ssl_context (Optional[ssl.SSLContext], optional): The SSL context to use for secure connections. Defaults to None.
kwargs (Any): Additional keyword arguments to pass to the websocket server.
Yields:
str: The URI of the websocket server.
"""
server_dict: Dict[str, WebSocketServer] = {}
def _run_server() -> None:
if _import_error is not None:
raise _import_error
# print(f" - _run_server(): starting server on ws://{host}:{port}", flush=True)
with ws_serve(
handler=partial(IOWebsockets._handler, on_connect=on_connect),
host=host,
port=port,
ssl_context=ssl_context,
**kwargs,
) as server:
# print(f" - _run_server(): server {server} started on ws://{host}:{port}", flush=True)
server_dict["server"] = server
# runs until the server is shutdown
server.serve_forever()
return
# start server in a separate thread
thread = threading.Thread(target=_run_server)
thread.start()
try:
while "server" not in server_dict:
sleep(0.1)
yield f"ws://{host}:{port}"
finally:
# print(f" - run_server_in_thread(): shutting down server on ws://{host}:{port}", flush=True)
# gracefully stop server
if "server" in server_dict:
# print(f" - run_server_in_thread(): shutting down server {server_dict['server']}", flush=True)
server_dict["server"].shutdown()
# wait for the thread to stop
if thread:
thread.join()
@property
def websocket(self) -> "ServerConnection":
"""The URI of the websocket server."""
return self._websocket
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
"""Print data to the output stream.
Args:
objects (any): The data to print.
sep (str, optional): The separator between objects. Defaults to " ".
end (str, optional): The end of the output. Defaults to "\n".
flush (bool, optional): Whether to flush the output. Defaults to False.
"""
xs = sep.join(map(str, objects)) + end
self._websocket.send(xs)
def input(self, prompt: str = "", *, password: bool = False) -> str:
"""Read a line from the input stream.
Args:
prompt (str, optional): The prompt to display. Defaults to "".
password (bool, optional): Whether to read a password. Defaults to False.
Returns:
str: The line read from the input stream.
"""
if prompt != "":
self._websocket.send(prompt)
msg = self._websocket.recv()
return msg.decode("utf-8") if isinstance(msg, bytes) else msg

View File

@ -11,6 +11,7 @@ from pydantic import BaseModel
from typing import Protocol
from autogen.cache.cache import Cache
from autogen.io.base import IOStream
from autogen.oai.openai_utils import get_key, is_valid_api_key, OAI_PRICE1K
from autogen.token_count_utils import count_token
@ -156,6 +157,8 @@ class OpenAIClient:
Returns:
The completion.
"""
iostream = IOStream.get_default()
completions: Completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions # type: ignore [attr-defined]
# If streaming is enabled and has messages, then iterate over the chunks of the response.
if params.get("stream", False) and "messages" in params:
@ -164,7 +167,7 @@ class OpenAIClient:
completion_tokens = 0
# Set the terminal text color to green
print("\033[32m", end="")
iostream.print("\033[32m", end="")
# Prepare for potential function call
full_function_call: Optional[Dict[str, Any]] = None
@ -216,15 +219,15 @@ class OpenAIClient:
# If content is present, print it to the terminal and update response variables
if content is not None:
print(content, end="", flush=True)
iostream.print(content, end="", flush=True)
response_contents[choice.index] += content
completion_tokens += 1
else:
# print()
# iostream.print()
pass
# Reset the terminal text color
print("\033[0m\n")
iostream.print("\033[0m\n")
# Prepare the final ChatCompletion object based on the accumulated data
model = chunk.model.replace("gpt-35", "gpt-3.5") # hack for Azure API
@ -825,25 +828,26 @@ class OpenAIWrapper:
def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
"""Print the usage summary."""
iostream = IOStream.get_default()
def print_usage(usage_summary: Optional[Dict[str, Any]], usage_type: str = "total") -> None:
word_from_type = "including" if usage_type == "total" else "excluding"
if usage_summary is None:
print("No actual cost incurred (all completions are using cache).", flush=True)
iostream.print("No actual cost incurred (all completions are using cache).", flush=True)
return
print(f"Usage summary {word_from_type} cached usage: ", flush=True)
print(f"Total cost: {round(usage_summary['total_cost'], 5)}", flush=True)
iostream.print(f"Usage summary {word_from_type} cached usage: ", flush=True)
iostream.print(f"Total cost: {round(usage_summary['total_cost'], 5)}", flush=True)
for model, counts in usage_summary.items():
if model == "total_cost":
continue #
print(
iostream.print(
f"* Model '{model}': cost: {round(counts['cost'], 5)}, prompt_tokens: {counts['prompt_tokens']}, completion_tokens: {counts['completion_tokens']}, total_tokens: {counts['total_tokens']}",
flush=True,
)
if self.total_usage_summary is None:
print('No usage summary. Please call "create" first.', flush=True)
iostream.print('No usage summary. Please call "create" first.', flush=True)
return
if isinstance(mode, list):
@ -856,14 +860,14 @@ class OpenAIWrapper:
elif "total" in mode:
mode = "total"
print("-" * 100, flush=True)
iostream.print("-" * 100, flush=True)
if mode == "both":
print_usage(self.actual_usage_summary, "actual")
print()
iostream.print()
if self.total_usage_summary != self.actual_usage_summary:
print_usage(self.total_usage_summary, "total")
else:
print(
iostream.print(
"All completions are non-cached: the total cost with cached completions is the same as actual cost.",
flush=True,
)
@ -873,7 +877,7 @@ class OpenAIWrapper:
print_usage(self.actual_usage_summary, "actual")
else:
raise ValueError(f'Invalid mode: {mode}, choose from "actual", "total", ["actual", "total"]')
print("-" * 100, flush=True)
iostream.print("-" * 100, flush=True)
def clear_usage_summary(self) -> None:
"""Clear the usage summary."""

View File

@ -0,0 +1,572 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "9a71fa36",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/microsoft/autogen/blob/main/notebook/agentchat_websockets.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
"\n",
"# Websockets: Streaming input and output using websockets\n",
"\n",
"This notebook demonstrates how to use the [`IOStream`](https://microsoft.github.io/autogen/docs/reference/io/base/IOStream) class to stream both input and output using websockets. The use of websockets allows you to build web clients that are more responsive than the one using web methods. The main difference is that the webosockets allows you to push data while you need to poll the server for new response using web mothods.\n",
"\n",
"\n",
"In this guide, we explore the capabilities of the [`IOStream`](https://microsoft.github.io/autogen/docs/reference/io/base/IOStream) class. It is specifically designed to enhance the development of clients such as web clients which use websockets for streaming both input and output. The [`IOStream`](https://microsoft.github.io/autogen/docs/reference/io/base/IOStream) stands out by enabling a more dynamic and interactive user experience for web applications.\n",
"\n",
"Websockets technology is at the core of this functionality, offering a significant advancement over traditional web methods by allowing data to be \"pushed\" to the client in real-time. This is a departure from the conventional approach where clients must repeatedly \"poll\" the server to check for any new responses. By employing the underlining [websockets](https://websockets.readthedocs.io/) library, the IOStream class facilitates a continuous, two-way communication channel between the server and client. This ensures that updates are received instantly, without the need for constant polling, thereby making web clients more efficient and responsive.\n",
"\n",
"The real power of websockets, leveraged through the [`IOStream`](https://microsoft.github.io/autogen/docs/reference/io/base/IOStream) class, lies in its ability to create highly responsive web clients. This responsiveness is critical for applications requiring real-time data updates such as chat applications. By integrating the [`IOStream`](https://microsoft.github.io/autogen/docs/reference/io/base/IOStream) class into your web application, you not only enhance user experience through immediate data transmission but also reduce the load on your server by eliminating unnecessary polling.\n",
"\n",
"In essence, the transition to using websockets through the [`IOStream`](https://microsoft.github.io/autogen/docs/reference/io/base/IOStream) class marks a significant enhancement in web client development. This approach not only streamlines the data exchange process between clients and servers but also opens up new possibilities for creating more interactive and engaging web applications. By following this guide, developers can harness the full potential of websockets and the [`IOStream`](https://microsoft.github.io/autogen/docs/reference/io/base/IOStream) class to push the boundaries of what is possible with web client responsiveness and interactivity.\n",
"\n",
"## Requirements\n",
"\n",
"````{=mdx}\n",
":::info Requirements\n",
"Some extra dependencies are needed for this notebook, which can be installed via pip:\n",
"\n",
"```bash\n",
"pip install pyautogen[websockets] fastapi uvicorn\n",
"```\n",
"\n",
"For more information, please refer to the [installation guide](/docs/installation/).\n",
":::\n",
"````"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "5ebd2397",
"metadata": {},
"source": [
"## Set your API Endpoint\n",
"\n",
"The [`config_list_from_json`](https://microsoft.github.io/autogen/docs/reference/oai/openai_utils#config_list_from_json) function loads a list of configurations from an environment variable or a json file."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "dca301a4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gpt-4\n"
]
}
],
"source": [
"from tempfile import TemporaryDirectory\n",
"\n",
"from websockets.sync.client import connect as ws_connect\n",
"\n",
"import autogen\n",
"from autogen.cache import Cache\n",
"from autogen.io.websockets import IOStream, IOWebsockets\n",
"\n",
"config_list = autogen.config_list_from_json(\n",
" \"OAI_CONFIG_LIST\",\n",
" filter_dict={\n",
" \"model\": [\"gpt-4\", \"gpt-3.5-turbo\", \"gpt-3.5-turbo-16k\"],\n",
" },\n",
")\n",
"\n",
"print(config_list[0][\"model\"])"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "92fde41f",
"metadata": {},
"source": [
"````{=mdx}\n",
":::tip\n",
"Learn more about configuring LLMs for agents [here](/docs/topics/llm_configuration).\n",
":::\n",
"````"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "2b9526e7",
"metadata": {},
"source": [
"## Defining `on_connect` function\n",
"\n",
"An `on_connect` function is a crucial part of applications that utilize websockets, acting as an event handler that is called whenever a new client connection is established. This function is designed to initiate any necessary setup, communication protocols, or data exchange procedures specific to the newly connected client. Essentially, it lays the groundwork for the interactive session that follows, configuring how the server and the client will communicate and what initial actions are to be taken once a connection is made. Now, let's delve into the details of how to define this function, especially in the context of using the AutoGen framework with websockets.\n",
"\n",
"\n",
"Upon a client's connection to the websocket server, the server automatically initiates a new instance of the [`IOWebsockets`](https://microsoft.github.io/autogen/docs/reference/io/websockets/IOWebsockets) class. This instance is crucial for managing the data flow between the server and the client. The `on_connect` function leverages this instance to set up the communication protocol, define interaction rules, and initiate any preliminary data exchanges or configurations required for the client-server interaction to proceed smoothly.\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9fb85afb",
"metadata": {},
"outputs": [],
"source": [
"def on_connect(iostream: IOWebsockets) -> None:\n",
" print(f\" - on_connect(): Connected to client using IOWebsockets {iostream}\", flush=True)\n",
"\n",
" print(\" - on_connect(): Receiving message from client.\", flush=True)\n",
"\n",
" initial_msg = iostream.input()\n",
"\n",
" llm_config = {\n",
" \"config_list\": config_list,\n",
" \"stream\": True,\n",
" }\n",
"\n",
" agent = autogen.ConversableAgent(\n",
" name=\"chatbot\",\n",
" system_message=\"Complete a task given to you and reply TERMINATE when the task is done. If asked about the weather, use tool weather_forecast(city) to get the weather forecast for a city.\",\n",
" llm_config=llm_config,\n",
" )\n",
"\n",
" # create a UserProxyAgent instance named \"user_proxy\"\n",
" user_proxy = autogen.UserProxyAgent(\n",
" name=\"user_proxy\",\n",
" system_message=\"A proxy for the user.\",\n",
" is_termination_msg=lambda x: x.get(\"content\", \"\") and x.get(\"content\", \"\").rstrip().endswith(\"TERMINATE\"),\n",
" human_input_mode=\"NEVER\",\n",
" max_consecutive_auto_reply=10,\n",
" code_execution_config=False,\n",
" )\n",
"\n",
" @user_proxy.register_for_execution()\n",
" @agent.register_for_llm(description=\"Weather forecats for a city\")\n",
" def weather_forecast(city: str) -> str:\n",
" return f\"The weather forecast for {city} is sunny.\"\n",
"\n",
" # we will use a temporary directory as the cache path root to ensure fresh completion each time\n",
" with TemporaryDirectory() as cache_path_root:\n",
" with Cache.disk(cache_path_root=cache_path_root) as cache:\n",
" print(\n",
" f\" - on_connect(): Initiating chat with agent {agent} using message '{initial_msg}'\",\n",
" flush=True,\n",
" )\n",
" user_proxy.initiate_chat( # noqa: F704\n",
" agent,\n",
" message=initial_msg,\n",
" cache=cache,\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "a1124796",
"metadata": {},
"source": [
"Here's an explanation on how a typical `on_connect` function such as the one in the example above is defined:\n",
"\n",
"1. **Receiving Initial Message**: Immediately after establishing a connection, receive an initial message from the client. This step is crucial for understanding the client's request or initiating the conversation flow.\n",
"\n",
"2. **Receiving Initial Message**: Immediately after establishing a connection, receive an initial message from the client. This step is crucial for understanding the client's request or initiating the conversation flow.\n",
"\n",
"3. **Configure the LLM**: Define the configuration for your large language model (LLM), specifying the list of configurations and the streaming capability. This configuration will be used to tailor the behavior of your conversational agent.\n",
"\n",
"4. **Instantiate ConversableAgent and UserProxyAgent**: Create an instance of ConversableAgent with a specific system message and the LLM configuration. Similarly, create a UserProxyAgent instance, defining its termination condition, human input mode, and other relevant parameters.\n",
"\n",
"5. **Define Agent-specific Functions**: If your conversable agent requires executing specific tasks, such as fetching a weather forecast in the example below, define these functions within the on_connect scope. Decorate these functions accordingly to link them with your agents.\n",
"\n",
"5. **Initiate Conversation**: Finally, use the `initiate_chat` method of your `UserProxyAgent` to start the interaction with the conversable agent, passing the initial message and a cache mechanism for efficiency."
]
},
{
"cell_type": "markdown",
"id": "62ef868a",
"metadata": {},
"source": [
"## Testing websockets server with Python client\n",
"\n",
"Testing an `on_connect` function with a Python client involves simulating a client-server interaction to ensure the setup, data exchange, and communication protocols function as intended. Heres a brief explanation on how to conduct this test using a Python client:\n",
"\n",
"1. **Start the Websocket Server**: Use the `IOWebsockets.run_server_in_thread method` to start the server in a separate thread, specifying the on_connect function and the port. This method returns the URI of the running websocket server.\n",
"\n",
"2. **Connect to the Server**: Open a connection to the server using the returned URI. This simulates a client initiating a connection to your websocket server.\n",
"\n",
"3. **Send a Message to the Server**: Once connected, send a message from the client to the server. This tests the server's ability to receive messages through the established websocket connection.\n",
"\n",
"4. **Receive and Process Messages**: Implement a loop to continuously receive messages from the server. Decode the messages if necessary, and process them accordingly. This step verifies the server's ability to respond back to the client's request.\n",
"\n",
"This test scenario effectively evaluates the interaction between a client and a server using the `on_connect` function, by simulating a realistic message exchange. It ensures that the server can handle incoming connections, process messages, and communicate responses back to the client, all critical functionalities for a robust websocket-based application."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4fbe004d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" - test_setup() with websocket server running on ws://127.0.0.1:8765.\n",
" - on_connect(): Connected to client using IOWebsockets <autogen.io.websockets.IOWebsockets object at 0x75ad84aa0d60>\n",
" - on_connect(): Receiving message from client.\n",
" - Connected to server on ws://127.0.0.1:8765\n",
" - Sending message to server.\n",
" - on_connect(): Initiating chat with agent <autogen.agentchat.conversable_agent.ConversableAgent object at 0x75ad84a72b30> using message 'Check out the weather in Paris and write a poem about it.'\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"Check out the weather in Paris and write a poem about it.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[32m\u001b[32m\u001b[0m\n",
"\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\n",
"\u001b[32m***** Suggested tool Call (call_U5VR0hck9KhDFWPdvmo1Eoke): weather_forecast *****\u001b[0m\n",
"Arguments: \n",
"{\n",
" \"city\": \"Paris\"\n",
"}\n",
"\u001b[32m*********************************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION weather_forecast...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"call_U5VR0hck9KhDFWPdvmo1Eoke\" *****\u001b[0m\n",
"The weather forecast for Paris is sunny.\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[32m\u001b[32mIn the city of love, shines the sun above,\n",
"Paris basks in golden rays, a beautiful day to praise.\n",
"Strolling down the Champs Elysées, the warm light leads the way,\n",
"In the glow, silhouettes dance, a perfect setting for romance.\n",
"\n",
"In the sunlight, the Seine sparkles bright, reflecting the City of Light,\n",
"Not a cloud in the crystal-clear blue sky, as the doves sail high.\n",
"Sunny Paris so profound, beauty all around,\n",
"Alive under the radiant crown, she wears her sunlight like a gown.\n",
"\n",
"TERMINATE\u001b[0m\n",
"\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"In the city of love, shines the sun above,\n",
"Paris basks in golden rays, a beautiful day to praise.\n",
"Strolling down the Champs Elysées, the warm light leads the way,\n",
"In the glow, silhouettes dance, a perfect setting for romance.\n",
"\n",
"In the sunlight, the Seine sparkles bright, reflecting the City of Light,\n",
"Not a cloud in the crystal-clear blue sky, as the doves sail high.\n",
"Sunny Paris so profound, beauty all around,\n",
"Alive under the radiant crown, she wears her sunlight like a gown.\n",
"\n",
"TERMINATE\n",
"\n",
" - Received TERMINATE message. Exiting.\n"
]
}
],
"source": [
"with IOWebsockets.run_server_in_thread(on_connect=on_connect, port=8765) as uri:\n",
" print(f\" - test_setup() with websocket server running on {uri}.\", flush=True)\n",
"\n",
" with ws_connect(uri) as websocket:\n",
" print(f\" - Connected to server on {uri}\", flush=True)\n",
"\n",
" print(\" - Sending message to server.\", flush=True)\n",
" # websocket.send(\"2+2=?\")\n",
" websocket.send(\"Check out the weather in Paris and write a poem about it.\")\n",
"\n",
" while True:\n",
" message = websocket.recv()\n",
" message = message.decode(\"utf-8\") if isinstance(message, bytes) else message\n",
"\n",
" print(message, end=\"\", flush=True)\n",
"\n",
" if \"TERMINATE\" in message:\n",
" print()\n",
" print(\" - Received TERMINATE message. Exiting.\", flush=True)\n",
" break"
]
},
{
"cell_type": "markdown",
"id": "3a656564",
"metadata": {},
"source": [
"## Testing websockets server running inside FastAPI server with HTML/JS client\n",
"\n",
"The code snippets below outlines an approach for testing an `on_connect` function in a web environment using [FastAPI](https://fastapi.tiangolo.com/) to serve a simple interactive HTML page. This method allows users to send messages through a web interface, which are then processed by the server running the AutoGen framework via websockets. Here's a step-by-step explanation:\n",
"\n",
"1. **FastAPI Application Setup**: The code initiates by importing necessary libraries and setting up a FastAPI application. FastAPI is a modern, fast web framework for building APIs with Python 3.7+ based on standard Python type hints.\n",
"\n",
"2. **HTML Template for User Interaction**: An HTML template is defined as a multi-line Python string, which includes a basic form for message input and a script for managing websocket communication. This template creates a user interface where messages can be sent to the server and responses are displayed dynamically.\n",
"\n",
"3. **Running the Websocket Server**: The `run_websocket_server` async context manager starts the websocket server using `IOWebsockets.run_server_in_thread` with the specified `on_connect` function and port. This server listens for incoming websocket connections.\n",
"\n",
"4. **FastAPI Route for Serving HTML Page**: A FastAPI route (`@app.get(\"/\")`) is defined to serve the HTML page to users. When a user accesses the root URL, the HTML content for the websocket chat is returned, allowing them to interact with the websocket server.\n",
"\n",
"5. **Starting the FastAPI Application**: Lastly, the FastAPI application is started using Uvicorn, an ASGI server, configured with the app and additional parameters as needed. The server is then launched to serve the FastAPI application, making the interactive HTML page accessible to users.\n",
"\n",
"This method of testing allows for interactive communication between the user and the server, providing a practical way to demonstrate and evaluate the behavior of the on_connect function in real-time. Users can send messages through the webpage, and the server processes these messages as per the logic defined in the on_connect function, showcasing the capabilities and responsiveness of the AutoGen framework's websocket handling in a user-friendly manner."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "5e55dc06",
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'fastapi'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[4], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mcontextlib\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m asynccontextmanager \u001b[38;5;66;03m# noqa: E402\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpathlib\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Path \u001b[38;5;66;03m# noqa: E402\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfastapi\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FastAPI \u001b[38;5;66;03m# noqa: E402\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfastapi\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mresponses\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m HTMLResponse \u001b[38;5;66;03m# noqa: E402\u001b[39;00m\n\u001b[1;32m 7\u001b[0m PORT \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m8000\u001b[39m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'fastapi'"
]
}
],
"source": [
"from contextlib import asynccontextmanager # noqa: E402\n",
"from pathlib import Path # noqa: E402\n",
"\n",
"from fastapi import FastAPI # noqa: E402\n",
"from fastapi.responses import HTMLResponse # noqa: E402\n",
"\n",
"PORT = 8000\n",
"\n",
"html = \"\"\"\n",
"<!DOCTYPE html>\n",
"<html>\n",
" <head>\n",
" <title>Autogen websocket test</title>\n",
" </head>\n",
" <body>\n",
" <h1>WebSocket Chat</h1>\n",
" <form action=\"\" onsubmit=\"sendMessage(event)\">\n",
" <input type=\"text\" id=\"messageText\" autocomplete=\"off\"/>\n",
" <button>Send</button>\n",
" </form>\n",
" <ul id='messages'>\n",
" </ul>\n",
" <script>\n",
" var ws = new WebSocket(\"ws://localhost:8080/ws\");\n",
" ws.onmessage = function(event) {\n",
" var messages = document.getElementById('messages')\n",
" var message = document.createElement('li')\n",
" var content = document.createTextNode(event.data)\n",
" message.appendChild(content)\n",
" messages.appendChild(message)\n",
" };\n",
" function sendMessage(event) {\n",
" var input = document.getElementById(\"messageText\")\n",
" ws.send(input.value)\n",
" input.value = ''\n",
" event.preventDefault()\n",
" }\n",
" </script>\n",
" </body>\n",
"</html>\n",
"\"\"\"\n",
"\n",
"\n",
"@asynccontextmanager\n",
"async def run_websocket_server(app):\n",
" with IOWebsockets.run_server_in_thread(on_connect=on_connect, port=8080) as uri:\n",
" print(f\"Websocket server started at {uri}.\", flush=True)\n",
"\n",
" yield\n",
"\n",
"\n",
"app = FastAPI(lifespan=run_websocket_server)\n",
"\n",
"\n",
"@app.get(\"/\")\n",
"async def get():\n",
" return HTMLResponse(html)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d92e50b5",
"metadata": {},
"outputs": [],
"source": [
"import uvicorn # noqa: E402\n",
"\n",
"config = uvicorn.Config(app)\n",
"server = uvicorn.Server(config)\n",
"await server.serve() # noqa: F704"
]
},
{
"cell_type": "markdown",
"id": "1c8c9f61",
"metadata": {},
"source": [
"The testing setup described above, leveraging FastAPI and websockets, not only serves as a robust testing framework for the on_connect function but also lays the groundwork for developing real-world applications. This approach exemplifies how web-based interactions can be made dynamic and real-time, a critical aspect of modern application development.\n",
"\n",
"For instance, this setup can be directly applied or adapted to build interactive chat applications, real-time data dashboards, or live support systems. The integration of websockets enables the server to push updates to clients instantly, a key feature for applications that rely on the timely delivery of information. For example, a chat application built on this framework can support instantaneous messaging between users, enhancing user engagement and satisfaction.\n",
"\n",
"Moreover, the simplicity and interactivity of the HTML page used for testing reflect how user interfaces can be designed to provide seamless experiences. Developers can expand upon this foundation to incorporate more sophisticated elements such as user authentication, message encryption, and custom user interactions, further tailoring the application to meet specific use case requirements.\n",
"\n",
"The flexibility of the FastAPI framework, combined with the real-time communication enabled by websockets, provides a powerful toolset for developers looking to build scalable, efficient, and highly interactive web applications. Whether it's for creating collaborative platforms, streaming services, or interactive gaming experiences, this testing setup offers a glimpse into the potential applications that can be developed with these technologies."
]
},
{
"cell_type": "markdown",
"id": "cfb50946",
"metadata": {},
"source": [
"## Testing websockets server with HTML/JS client\n",
"\n",
"The provided code snippet below is an example of how to create an interactive testing environment for an `on_connect` function using Python's built-in `http.server` module. This setup allows for real-time interaction within a web browser, enabling developers to test the websocket functionality in a more user-friendly and practical manner. Here's a breakdown of how this code operates and its potential applications:\n",
"\n",
"1. **Serving a Simple HTML Page**: The code starts by defining an HTML page that includes a form for sending messages and a list to display incoming messages. JavaScript is used to handle the form submission and websocket communication.\n",
"\n",
"2. **Temporary Directory for HTML File**: A temporary directory is created to store the HTML file. This approach ensures that the testing environment is clean and isolated, minimizing conflicts with existing files or configurations.\n",
"\n",
"3. **Custom HTTP Request Handler**: A custom subclass of `SimpleHTTPRequestHandler` is defined to serve the HTML file. This handler overrides the do_GET method to redirect the root path (`/`) to the `chat.html` page, ensuring that visitors to the server's root URL are immediately presented with the chat interface.\n",
"\n",
"4. **Starting the Websocket Server**: Concurrently, a websocket server is started on a different port using the `IOWebsockets.run_server_in_thread` method, with the previously defined `on_connect` function as the callback for new connections.\n",
"\n",
"5. **HTTP Server for the HTML Interface**: An HTTP server is instantiated to serve the HTML chat interface, enabling users to interact with the websocket server through a web browser.\n",
"\n",
"This setup showcases a practical application of integrating websockets with a simple HTTP server to create a dynamic and interactive web application. By using Python's standard library modules, it demonstrates a low-barrier entry to developing real-time applications such as chat systems, live notifications, or interactive dashboards.\n",
"\n",
"The key takeaway from this code example is how easily Python's built-in libraries can be leveraged to prototype and test complex web functionalities. For developers looking to build real-world applications, this approach offers a straightforward method to validate and refine websocket communication logic before integrating it into larger frameworks or systems. The simplicity and accessibility of this testing setup make it an excellent starting point for developing a wide range of interactive web applications.\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "708a98de",
"metadata": {},
"outputs": [],
"source": [
"from http.server import HTTPServer, SimpleHTTPRequestHandler # noqa: E402\n",
"\n",
"PORT = 8000\n",
"\n",
"html = \"\"\"\n",
"<!DOCTYPE html>\n",
"<html>\n",
" <head>\n",
" <title>Autogen websocket test</title>\n",
" </head>\n",
" <body>\n",
" <h1>WebSocket Chat</h1>\n",
" <form action=\"\" onsubmit=\"sendMessage(event)\">\n",
" <input type=\"text\" id=\"messageText\" autocomplete=\"off\"/>\n",
" <button>Send</button>\n",
" </form>\n",
" <ul id='messages'>\n",
" </ul>\n",
" <script>\n",
" var ws = new WebSocket(\"ws://localhost:8080/ws\");\n",
" ws.onmessage = function(event) {\n",
" var messages = document.getElementById('messages')\n",
" var message = document.createElement('li')\n",
" var content = document.createTextNode(event.data)\n",
" message.appendChild(content)\n",
" messages.appendChild(message)\n",
" };\n",
" function sendMessage(event) {\n",
" var input = document.getElementById(\"messageText\")\n",
" ws.send(input.value)\n",
" input.value = ''\n",
" event.preventDefault()\n",
" }\n",
" </script>\n",
" </body>\n",
"</html>\n",
"\"\"\"\n",
"\n",
"with TemporaryDirectory() as temp_dir:\n",
" # create a simple HTTP webpage\n",
" path = Path(temp_dir) / \"chat.html\"\n",
" with open(path, \"w\") as f:\n",
" f.write(html)\n",
"\n",
" #\n",
" class MyRequestHandler(SimpleHTTPRequestHandler):\n",
" def __init__(self, *args, **kwargs):\n",
" super().__init__(*args, directory=temp_dir, **kwargs)\n",
"\n",
" def do_GET(self):\n",
" if self.path == \"/\":\n",
" self.path = \"/chat.html\"\n",
" return SimpleHTTPRequestHandler.do_GET(self)\n",
"\n",
" handler = MyRequestHandler\n",
"\n",
" with IOWebsockets.run_server_in_thread(on_connect=on_connect, port=8080) as uri:\n",
" print(f\"Websocket server started at {uri}.\", flush=True)\n",
"\n",
" with HTTPServer((\"\", PORT), handler) as httpd:\n",
" print(\"HTTP server started at http://localhost:\" + str(PORT))\n",
" httpd.serve_forever()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19656d0e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"front_matter": {
"description": "Websockets facilitate real-time, bidirectional communication between web clients and servers, enhancing the responsiveness and interactivity of AutoGen-powered applications.",
"tags": [
"websockets",
"streaming"
]
},
"kernelspec": {
"display_name": "flaml_dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -67,6 +67,8 @@ files = [
"autogen/exception_utils.py",
"autogen/coding",
"autogen/oai/openai_utils.py",
"autogen/io",
"test/io",
]
exclude = [

View File

@ -65,8 +65,9 @@ setuptools.setup(
"graph": ["networkx", "matplotlib"],
"websurfer": ["beautifulsoup4", "markdownify", "pdfminer.six", "pathvalidate"],
"redis": ["redis"],
"websockets": ["websockets>=12.0,<13"],
"jupyter-executor": jupyter_executor,
"types": ["mypy==1.9.0"] + jupyter_executor,
"types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor,
},
classifiers=[
"Programming Language :: Python :: 3",

View File

@ -157,7 +157,8 @@ def test_two_agents_logging(db_connection):
assert row["session_id"] and row["session_id"] == session_id
assert row["class"] in ["AzureOpenAI", "OpenAI"]
init_args = json.loads(row["init_args"])
assert "api_version" in init_args
if row["class"] == "AzureOpenAI":
assert "api_version" in init_args
assert row["timestamp"], "timestamp is empty"
# Verify oai wrapper table

28
test/io/test_base.py Normal file
View File

@ -0,0 +1,28 @@
from typing import Any
from autogen.io import IOConsole, IOStream, IOWebsockets
class TestIOStream:
def test_initial_default_io_stream(self) -> None:
assert isinstance(IOStream.get_default(), IOConsole)
def test_set_default_io_stream(self) -> None:
class MyIOStream(IOStream):
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
pass
def input(self, prompt: str = "", *, password: bool = False) -> str:
return "Hello, World!"
assert isinstance(IOStream.get_default(), IOConsole)
with IOStream.set_default(MyIOStream()):
assert isinstance(IOStream.get_default(), MyIOStream)
with IOStream.set_default(IOConsole()):
assert isinstance(IOStream.get_default(), IOConsole)
assert isinstance(IOStream.get_default(), MyIOStream)
assert isinstance(IOStream.get_default(), IOConsole)

38
test/io/test_console.py Normal file
View File

@ -0,0 +1,38 @@
from unittest.mock import MagicMock, patch
import pytest
from autogen.io import IOConsole
class TestConsoleIO:
def setup_method(self) -> None:
self.console_io = IOConsole()
@patch("builtins.print")
def test_print(self, mock_print: MagicMock) -> None:
# calling the print method should call the mock of the builtin print function
self.console_io.print("Hello, World!", flush=True)
mock_print.assert_called_once_with("Hello, World!", end="\n", sep=" ", flush=True)
@patch("builtins.input")
def test_input(self, mock_input: MagicMock) -> None:
# calling the input method should call the mock of the builtin input function
mock_input.return_value = "Hello, World!"
actual = self.console_io.input("Hi!")
assert actual == "Hello, World!"
mock_input.assert_called_once_with("Hi!")
@pytest.mark.parametrize("prompt", ["", "Password: ", "Enter you password:"])
def test_input_password(self, monkeypatch: pytest.MonkeyPatch, prompt: str) -> None:
mock_getpass = MagicMock()
mock_getpass.return_value = "123456"
monkeypatch.setattr("getpass.getpass", mock_getpass)
actual = self.console_io.input(prompt, password=True)
assert actual == "123456"
if prompt == "":
mock_getpass.assert_called_once_with("Password: ")
else:
mock_getpass.assert_called_once_with(prompt)

176
test/io/test_websockets.py Normal file
View File

@ -0,0 +1,176 @@
from tempfile import TemporaryDirectory
from typing import Dict
import pytest
from autogen.io.base import IOStream
from conftest import skip_openai
import autogen
from autogen.cache.cache import Cache
from autogen.io import IOWebsockets
KEY_LOC = "notebook"
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
# Check if the websockets module is available
try:
from websockets.sync.client import connect as ws_connect
except ImportError: # pragma: no cover
skip_test = True
else:
skip_test = False
@pytest.mark.skipif(skip_test, reason="websockets module is not available")
class TestConsoleIOWithWebsockets:
def test_input_print(self) -> None:
print()
print("Testing input/print", flush=True)
def on_connect(iostream: IOWebsockets) -> None:
print(f" - on_connect(): Connected to client using IOWebsockets {iostream}", flush=True)
print(" - on_connect(): Receiving message from client.", flush=True)
msg = iostream.input()
print(f" - on_connect(): Received message '{msg}' from client.", flush=True)
assert msg == "Hello world!"
for msg in ["Hello, World!", "Over and out!"]:
print(f" - on_connect(): Sending message '{msg}' to client.", flush=True)
iostream.print(msg)
print(" - on_connect(): Receiving message from client.", flush=True)
msg = iostream.input("May I?")
print(f" - on_connect(): Received message '{msg}' from client.", flush=True)
assert msg == "Yes"
return
with IOWebsockets.run_server_in_thread(on_connect=on_connect, port=8765) as uri:
print(f" - test_setup() with websocket server running on {uri}.", flush=True)
with ws_connect(uri) as websocket:
print(f" - Connected to server on {uri}", flush=True)
print(" - Sending message to server.", flush=True)
websocket.send("Hello world!")
for expected in ["Hello, World!", "Over and out!", "May I?"]:
print(" - Receiving message from server.", flush=True)
message = websocket.recv()
message = message.decode("utf-8") if isinstance(message, bytes) else message
# drop the newline character
if message.endswith("\n"):
message = message[:-1]
print(
f" - Asserting received message '{message}' is the same as the expected message '{expected}'",
flush=True,
)
assert message == expected
print(" - Sending message 'Yes' to server.", flush=True)
websocket.send("Yes")
print("Test passed.", flush=True)
@pytest.mark.skipif(skip_openai, reason="requested to skip")
def test_chat(self) -> None:
print("Testing setup", flush=True)
success_dict = {"success": False}
def on_connect(iostream: IOWebsockets, success_dict: Dict[str, bool] = success_dict) -> None:
print(f" - on_connect(): Connected to client using IOWebsockets {iostream}", flush=True)
print(" - on_connect(): Receiving message from client.", flush=True)
initial_msg = iostream.input()
config_list = autogen.config_list_from_json(
OAI_CONFIG_LIST,
filter_dict={
"model": [
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4",
"gpt-4-0314",
"gpt4",
"gpt-4-32k",
"gpt-4-32k-0314",
"gpt-4-32k-v0314",
],
},
file_location=KEY_LOC,
)
llm_config = {
"config_list": config_list,
"stream": True,
}
agent = autogen.ConversableAgent(
name="chatbot",
system_message="Complete a task given to you and reply TERMINATE when the task is done.",
llm_config=llm_config,
)
# create a UserProxyAgent instance named "user_proxy"
user_proxy = autogen.UserProxyAgent(
name="user_proxy",
system_message="A proxy for the user.",
is_termination_msg=lambda x: x.get("content", "")
and x.get("content", "").rstrip().endswith("TERMINATE"),
human_input_mode="NEVER",
max_consecutive_auto_reply=10,
)
# we will use a temporary directory as the cache path root to ensure fresh completion each time
with TemporaryDirectory() as cache_path_root:
with Cache.disk(cache_path_root=cache_path_root) as cache:
print(
f" - on_connect(): Initiating chat with agent {agent} using message '{initial_msg}'",
flush=True,
)
user_proxy.initiate_chat( # noqa: F704
agent,
message=initial_msg,
cache=cache,
)
success_dict["success"] = True
return
with IOWebsockets.run_server_in_thread(on_connect=on_connect, port=8765) as uri:
print(f" - test_setup() with websocket server running on {uri}.", flush=True)
with ws_connect(uri) as websocket:
print(f" - Connected to server on {uri}", flush=True)
print(" - Sending message to server.", flush=True)
# websocket.send("2+2=?")
websocket.send("Please write a poem about spring in a city of your choice.")
while True:
message = websocket.recv()
message = message.decode("utf-8") if isinstance(message, bytes) else message
# drop the newline character
if message.endswith("\n"):
message = message[:-1]
print(message, end="", flush=True)
if "TERMINATE" in message:
print()
print(" - Received TERMINATE message. Exiting.", flush=True)
break
assert success_dict["success"]
print("Test passed.", flush=True)