autogen/samples/tools/finetuning/tests/test_conversable_agent_update_model.py
Li Jiang 42b27b9a9d
Add isort (#2265)
* Add isort

* Apply isort on py files

* Fix circular import

* Fix format for notebooks

* Fix format

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
2024-04-05 02:26:06 +00:00

220 lines
6.8 KiB
Python

import sys
import pytest
from autogen import AssistantAgent, UserProxyAgent
sys.path.append("samples/tools/finetuning")
from typing import Dict # noqa: E402
from finetuning import update_model # noqa: E402
# NOTE(review): this append runs after every import visible in this file, so it
# cannot influence the imports above; presumably kept for helpers living under
# test/ that are resolved later — confirm before removing.
sys.path.append("test")
# Canned reply produced by the stub model client's ``create``; the tests assert
# that the chat summary equals this string.
TEST_CUSTOM_RESPONSE = "This is a custom response."
# Model name written into every stub ``config_list`` entry.
TEST_LOCAL_MODEL_NAME = "local_model_name"
def _make_custom_model_cls(update_stats=None):
    """Build a stub model-client class for ``register_model_client``.

    The returned class is always named ``UpdatableCustomModel`` because the
    stub config entries refer to it by that name via ``model_client_cls``.
    Its ``create`` returns a canned response whose single choice carries
    ``TEST_CUSTOM_RESPONSE``.

    Args:
        update_stats: Dict returned (as a fresh copy on each call) by the
            class's ``update_model`` method. When ``None``, the class is
            built WITHOUT an ``update_model`` method, which lets a test
            exercise the "client cannot be updated" path.

    Returns:
        The stub class itself (not an instance).
    """
    from types import SimpleNamespace

    class UpdatableCustomModel:
        def __init__(self, config: Dict):
            self.model = config["model"]
            self.model_name = config["model"]

        def create(self, params):
            # The response needs to follow Client.ClientResponseProtocol:
            # ``response.choices[i].message.content`` plus ``response.model``.
            message = SimpleNamespace(content=TEST_CUSTOM_RESPONSE)
            return SimpleNamespace(choices=[SimpleNamespace(message=message)], model=self.model)

        def message_retrieval(self, response):
            return [response.choices[0].message.content]

        def cost(self, response) -> float:
            """Calculate the cost of the response (this stub is always free)."""
            response.cost = 0
            return 0

        @staticmethod
        def get_usage(response) -> Dict:
            return {}

    if update_stats is not None:
        # Attached after class creation so the "missing update_model" variant
        # can share the rest of the stub unchanged.
        def _update_model(self, preference_data, messages, **kwargs):
            return dict(update_stats)

        UpdatableCustomModel.update_model = _update_model

    return UpdatableCustomModel


def _make_llm_assistant(model_client_cls, num_clients=1):
    """Create an ``AssistantAgent`` whose LLM config uses only stub clients.

    One config entry is built per client, each naming ``model_client_cls``
    via ``model_client_cls``. ``register_model_client`` is then called once
    per entry.
    """
    config_list = [
        {"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": model_client_cls.__name__}
        for _ in range(num_clients)
    ]
    assistant = AssistantAgent(
        "assistant",
        system_message="You are a helpful assistant.",
        human_input_mode="NEVER",
        llm_config={"config_list": config_list},
    )
    for _ in range(num_clients):
        assistant.register_model_client(model_client_cls=model_client_cls)
    return assistant


def _make_user_proxy():
    """Create the user proxy shared by every test: no LLM, no code execution, one auto-reply."""
    return UserProxyAgent(
        "user_proxy",
        human_input_mode="NEVER",
        max_consecutive_auto_reply=1,
        code_execution_config=False,
        llm_config=False,
    )


def test_custom_model_client():
    """Stats returned by a client's ``update_model`` are exposed under ``update_stats``."""
    TEST_LOSS = 0.5
    assistant = _make_llm_assistant(_make_custom_model_cls(update_stats={"loss": TEST_LOSS}))
    user_proxy = _make_user_proxy()

    res = user_proxy.initiate_chat(assistant, message="2+2=", silent=True)
    response_content = res.summary
    assert response_content == TEST_CUSTOM_RESPONSE

    preference_data = [("this is what the response should have been like", response_content)]
    update_model_stats = update_model(assistant, preference_data, user_proxy)
    assert update_model_stats["update_stats"]["loss"] == TEST_LOSS


def test_update_model_without_client_raises_error():
    """An assistant with ``llm_config=False`` cannot be updated."""
    assistant = AssistantAgent(
        "assistant",
        system_message="You are a helpful assistant.",
        human_input_mode="NEVER",
        max_consecutive_auto_reply=0,
        llm_config=False,
        code_execution_config=False,
    )
    user_proxy = _make_user_proxy()

    user_proxy.initiate_chat(assistant, message="2+2=", silent=True)
    with pytest.raises(ValueError):
        update_model(assistant, [], user_proxy)


def test_custom_model_update_func_missing_raises_error():
    """A registered client class lacking ``update_model`` leads to ``NotImplementedError``."""
    assistant = _make_llm_assistant(_make_custom_model_cls(update_stats=None))
    user_proxy = _make_user_proxy()

    res = user_proxy.initiate_chat(assistant, message="2+2=", silent=True)
    response_content = res.summary
    assert response_content == TEST_CUSTOM_RESPONSE

    with pytest.raises(NotImplementedError):
        update_model(assistant, [], user_proxy)


def test_multiple_model_clients_raises_error():
    """Updating an assistant configured with two model clients is rejected."""
    assistant = _make_llm_assistant(_make_custom_model_cls(update_stats={}), num_clients=2)
    user_proxy = _make_user_proxy()

    user_proxy.initiate_chat(assistant, message="2+2=", silent=True)
    with pytest.raises(ValueError):
        update_model(assistant, [], user_proxy)