mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-01 05:12:22 +00:00

* Add isort * Apply isort on py files * Fix circular import * Fix format for notebooks * Fix format --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com>
177 lines
6.4 KiB
Python
177 lines
6.4 KiB
Python
from tempfile import TemporaryDirectory
|
|
from typing import Dict
|
|
|
|
import pytest
|
|
from conftest import skip_openai
|
|
|
|
import autogen
|
|
from autogen.cache.cache import Cache
|
|
from autogen.io import IOWebsockets
|
|
from autogen.io.base import IOStream
|
|
|
|
# Passed as `file_location` to `autogen.config_list_from_json` in `test_chat`
# (presumably the directory that holds the config file -- confirm against the repo layout).
KEY_LOC = "notebook"
# Passed as the first positional argument to `autogen.config_list_from_json` in `test_chat`
# (presumably the name of the config file / environment variable -- confirm).
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
|
|
|
|
# Check if the websockets module is available.
# The import is attempted once at module load; an ImportError is swallowed and
# recorded in `skip_test`, which the `pytest.mark.skipif` decorator on the test
# class reads so the tests are skipped rather than failing at import time.
try:
    from websockets.sync.client import connect as ws_connect
except ImportError:  # pragma: no cover
    skip_test = True
else:
    skip_test = False
|
|
|
|
|
|
@pytest.mark.skipif(skip_test, reason="websockets module is not available")
class TestConsoleIOWithWebsockets:
    """Exercise ``IOWebsockets`` end to end with a synchronous websocket client.

    Every test starts a server with ``IOWebsockets.run_server_in_thread`` on
    port 8765, connects to it through ``ws_connect`` and exchanges messages
    with the ``on_connect`` handler that runs on the server side.
    """

    @staticmethod
    def _recv_text(websocket) -> str:
        """Receive one message and return it as ``str`` without one trailing newline.

        Args:
            websocket: Client connection object exposing ``recv()``.

        Returns:
            The received payload, decoded from UTF-8 when it arrives as bytes,
            with a single trailing ``"\\n"`` removed if present.
        """
        payload = websocket.recv()
        if isinstance(payload, bytes):
            payload = payload.decode("utf-8")
        # drop the newline character
        return payload[:-1] if payload.endswith("\n") else payload

    def test_input_print(self) -> None:
        """Check that ``input``/``print`` on the server stream reach the client in order."""
        print()
        print("Testing input/print", flush=True)

        def on_connect(iostream: IOWebsockets) -> None:
            """Server-side script: read a greeting, print two lines, then prompt once."""
            print(f" - on_connect(): Connected to client using IOWebsockets {iostream}", flush=True)

            print(" - on_connect(): Receiving message from client.", flush=True)
            greeting = iostream.input()
            print(f" - on_connect(): Received message '{greeting}' from client.", flush=True)
            assert greeting == "Hello world!"

            for line in ("Hello, World!", "Over and out!"):
                print(f" - on_connect(): Sending message '{line}' to client.", flush=True)
                iostream.print(line)

            print(" - on_connect(): Receiving message from client.", flush=True)
            # The prompt itself is delivered to the client as a message (the
            # client below expects "May I?" as its third received message).
            answer = iostream.input("May I?")
            print(f" - on_connect(): Received message '{answer}' from client.", flush=True)
            assert answer == "Yes"

        with IOWebsockets.run_server_in_thread(on_connect=on_connect, port=8765) as uri:
            print(f" - test_setup() with websocket server running on {uri}.", flush=True)

            with ws_connect(uri) as websocket:
                print(f" - Connected to server on {uri}", flush=True)

                print(" - Sending message to server.", flush=True)
                websocket.send("Hello world!")

                for wanted in ("Hello, World!", "Over and out!", "May I?"):
                    print(" - Receiving message from server.", flush=True)
                    received = self._recv_text(websocket)

                    print(
                        f" - Asserting received message '{received}' is the same as the expected message '{wanted}'",
                        flush=True,
                    )
                    assert received == wanted

                print(" - Sending message 'Yes' to server.", flush=True)
                websocket.send("Yes")

        print("Test passed.", flush=True)

    @pytest.mark.skipif(skip_openai, reason="requested to skip")
    def test_chat(self) -> None:
        """Run a streamed agent chat over the websocket until "TERMINATE" is received."""
        print("Testing setup", flush=True)

        # Shared with the server-side handler so the test can verify afterwards
        # that the handler ran to completion.
        success_dict = {"success": False}

        def on_connect(iostream: IOWebsockets, success_dict: Dict[str, bool] = success_dict) -> None:
            """Server-side script: read the task from the client and run the chat."""
            print(f" - on_connect(): Connected to client using IOWebsockets {iostream}", flush=True)

            print(" - on_connect(): Receiving message from client.", flush=True)
            opening_message = iostream.input()

            accepted_models = [
                "gpt-3.5-turbo",
                "gpt-3.5-turbo-16k",
                "gpt-4",
                "gpt-4-0314",
                "gpt4",
                "gpt-4-32k",
                "gpt-4-32k-0314",
                "gpt-4-32k-v0314",
            ]
            configs = autogen.config_list_from_json(
                OAI_CONFIG_LIST,
                filter_dict={"model": accepted_models},
                file_location=KEY_LOC,
            )

            agent = autogen.ConversableAgent(
                name="chatbot",
                system_message="Complete a task given to you and reply TERMINATE when the task is done.",
                llm_config={"config_list": configs, "stream": True},
            )

            def ends_with_terminate(x):
                """Return a truthy value when the message content ends with "TERMINATE"."""
                content = x.get("content", "")
                return content and content.rstrip().endswith("TERMINATE")

            # create a UserProxyAgent instance named "user_proxy"
            user_proxy = autogen.UserProxyAgent(
                name="user_proxy",
                system_message="A proxy for the user.",
                is_termination_msg=ends_with_terminate,
                human_input_mode="NEVER",
                max_consecutive_auto_reply=10,
            )

            # we will use a temporary directory as the cache path root to ensure fresh completion each time
            with TemporaryDirectory() as cache_path_root, Cache.disk(cache_path_root=cache_path_root) as cache:
                print(
                    f" - on_connect(): Initiating chat with agent {agent} using message '{opening_message}'",
                    flush=True,
                )
                user_proxy.initiate_chat(agent, message=opening_message, cache=cache)  # noqa: F704

            success_dict["success"] = True

        with IOWebsockets.run_server_in_thread(on_connect=on_connect, port=8765) as uri:
            print(f" - test_setup() with websocket server running on {uri}.", flush=True)

            with ws_connect(uri) as websocket:
                print(f" - Connected to server on {uri}", flush=True)

                print(" - Sending message to server.", flush=True)
                websocket.send("Please write a poem about spring in a city of your choice.")

                while True:
                    chunk = self._recv_text(websocket)
                    print(chunk, end="", flush=True)

                    if "TERMINATE" not in chunk:
                        continue

                    print()
                    print(" - Received TERMINATE message. Exiting.", flush=True)
                    break

        assert success_dict["success"]
        print("Test passed.", flush=True)
|