# Tests for the update_model utility in samples/tools/finetuning.
import pytest
from autogen import AssistantAgent, UserProxyAgent

import sys

sys.path.append("samples/tools/finetuning")
from finetuning import update_model  # noqa: E402
from typing import Dict  # noqa: E402

sys.path.append("test")

TEST_CUSTOM_RESPONSE = "This is a custom response."
TEST_LOCAL_MODEL_NAME = "local_model_name"


def test_custom_model_client():
    # A custom model client that implements update_model should return update stats
    # containing the loss reported by the client.
    TEST_LOSS = 0.5

    class UpdatableCustomModel:
        def __init__(self, config: Dict):
            self.model = config["model"]
            self.model_name = config["model"]

        def create(self, params):
            from types import SimpleNamespace

            response = SimpleNamespace()
            # need to follow Client.ClientResponseProtocol
            response.choices = []
            choice = SimpleNamespace()
            choice.message = SimpleNamespace()
            choice.message.content = TEST_CUSTOM_RESPONSE
            response.choices.append(choice)
            response.model = self.model
            return response

        def message_retrieval(self, response):
            return [response.choices[0].message.content]

        def cost(self, response) -> float:
            """Calculate the cost of the response."""
            response.cost = 0
            return 0

        @staticmethod
        def get_usage(response) -> Dict:
            return {}

        def update_model(self, preference_data, messages, **kwargs):
            return {"loss": TEST_LOSS}

    config_list = [{"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "UpdatableCustomModel"}]
    assistant = AssistantAgent(
        "assistant",
        system_message="You are a helpful assistant.",
        human_input_mode="NEVER",
        llm_config={"config_list": config_list},
    )
    assistant.register_model_client(model_client_cls=UpdatableCustomModel)

    user_proxy = UserProxyAgent(
        "user_proxy",
        human_input_mode="NEVER",
        max_consecutive_auto_reply=1,
        code_execution_config=False,
        llm_config=False,
    )

    res = user_proxy.initiate_chat(assistant, message="2+2=", silent=True)
    response_content = res.summary
    assert response_content == TEST_CUSTOM_RESPONSE

    preference_data = [("this is what the response should have been like", response_content)]
    update_model_stats = update_model(assistant, preference_data, user_proxy)
    assert update_model_stats["update_stats"]["loss"] == TEST_LOSS


def test_update_model_without_client_raises_error():
    # update_model should raise a ValueError when the agent has no custom model client registered.
    assistant = AssistantAgent(
        "assistant",
        system_message="You are a helpful assistant.",
        human_input_mode="NEVER",
        max_consecutive_auto_reply=0,
        llm_config=False,
        code_execution_config=False,
    )
    user_proxy = UserProxyAgent(
        "user_proxy",
        human_input_mode="NEVER",
        max_consecutive_auto_reply=1,
        code_execution_config=False,
        llm_config=False,
    )
    user_proxy.initiate_chat(assistant, message="2+2=", silent=True)

    with pytest.raises(ValueError):
        update_model(assistant, [], user_proxy)


def test_custom_model_update_func_missing_raises_error():
    # update_model should raise a NotImplementedError when the registered client
    # does not implement an update_model method.
    class UpdatableCustomModel:
        def __init__(self, config: Dict):
            self.model = config["model"]
            self.model_name = config["model"]

        def create(self, params):
            from types import SimpleNamespace

            response = SimpleNamespace()
            # need to follow Client.ClientResponseProtocol
            response.choices = []
            choice = SimpleNamespace()
            choice.message = SimpleNamespace()
            choice.message.content = TEST_CUSTOM_RESPONSE
            response.choices.append(choice)
            response.model = self.model
            return response

        def message_retrieval(self, response):
            return [response.choices[0].message.content]

        def cost(self, response) -> float:
            """Calculate the cost of the response."""
            response.cost = 0
            return 0

        @staticmethod
        def get_usage(response) -> Dict:
            return {}

    config_list = [{"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "UpdatableCustomModel"}]
    assistant = AssistantAgent(
        "assistant",
        system_message="You are a helpful assistant.",
        human_input_mode="NEVER",
        llm_config={"config_list": config_list},
    )
    assistant.register_model_client(model_client_cls=UpdatableCustomModel)

    user_proxy = UserProxyAgent(
        "user_proxy",
        human_input_mode="NEVER",
        max_consecutive_auto_reply=1,
        code_execution_config=False,
        llm_config=False,
    )

    res = user_proxy.initiate_chat(assistant, message="2+2=", silent=True)
    response_content = res.summary
    assert response_content == TEST_CUSTOM_RESPONSE

    with pytest.raises(NotImplementedError):
        update_model(assistant, [], user_proxy)


def test_multiple_model_clients_raises_error():
    # update_model should raise a ValueError when more than one model client is registered.
    class UpdatableCustomModel:
        def __init__(self, config: Dict):
            self.model = config["model"]
            self.model_name = config["model"]

        def create(self, params):
            from types import SimpleNamespace

            response = SimpleNamespace()
            # need to follow Client.ClientResponseProtocol
            response.choices = []
            choice = SimpleNamespace()
            choice.message = SimpleNamespace()
            choice.message.content = TEST_CUSTOM_RESPONSE
            response.choices.append(choice)
            response.model = self.model
            return response

        def message_retrieval(self, response):
            return [response.choices[0].message.content]

        def cost(self, response) -> float:
            """Calculate the cost of the response."""
            response.cost = 0
            return 0

        @staticmethod
        def get_usage(response) -> Dict:
            return {}

        def update_model(self, preference_data, messages, **kwargs):
            return {}

    config_list = [
        {"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "UpdatableCustomModel"},
        {"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "UpdatableCustomModel"},
    ]
    assistant = AssistantAgent(
        "assistant",
        system_message="You are a helpful assistant.",
        human_input_mode="NEVER",
        llm_config={"config_list": config_list},
    )
    assistant.register_model_client(model_client_cls=UpdatableCustomModel)
    assistant.register_model_client(model_client_cls=UpdatableCustomModel)

    user_proxy = UserProxyAgent(
        "user_proxy",
        human_input_mode="NEVER",
        max_consecutive_auto_reply=1,
        code_execution_config=False,
        llm_config=False,
    )
    user_proxy.initiate_chat(assistant, message="2+2=", silent=True)

    with pytest.raises(ValueError):
        update_model(assistant, [], user_proxy)