From 978cbd2e8995ce22d2cbadd92a46816a0a3287a0 Mon Sep 17 00:00:00 2001
From: EeS
Date: Tue, 13 May 2025 11:32:14 +0900
Subject: [PATCH] FIX/mistral could not receive name field (#6503)

## Why are these changes needed?

Mistral models could not receive the OpenAI-style `name` field on chat messages, so this adds a Mistral-specific message transformer that omits it (a minimal sketch of the transformer mechanism follows the diff).

## Related issue number

Closes #6147

## Checks

- [ ] I've included any doc changes needed.
- [x] I've added tests (if relevant) corresponding to the changes introduced in this PR.
- [x] I've made sure all auto checks have passed.

Co-authored-by: Eric Zhu
---
 .../src/autogen_core/models/_model_client.py  | 25 +++++++++
 .../models/openai/_message_transform.py       | 51 ++++++++++++++++++-
 .../tests/models/test_openai_model_client.py  | 13 +++++
 3 files changed, 87 insertions(+), 2 deletions(-)

diff --git a/python/packages/autogen-core/src/autogen_core/models/_model_client.py b/python/packages/autogen-core/src/autogen_core/models/_model_client.py
index b65a19853..dac7b4daa 100644
--- a/python/packages/autogen-core/src/autogen_core/models/_model_client.py
+++ b/python/packages/autogen-core/src/autogen_core/models/_model_client.py
@@ -37,9 +37,15 @@ class ModelFamily:
     CLAUDE_3_5_HAIKU = "claude-3-5-haiku"
     CLAUDE_3_5_SONNET = "claude-3-5-sonnet"
     CLAUDE_3_7_SONNET = "claude-3-7-sonnet"
+    CODESRAL = "codestral"
+    OPEN_CODESRAL_MAMBA = "open-codestral-mamba"
+    MISTRAL = "mistral"
+    MINISTRAL = "ministral"
+    PIXTRAL = "pixtral"
     UNKNOWN = "unknown"
 
     ANY: TypeAlias = Literal[
+        # openai_models
         "gpt-41",
         "gpt-45",
         "gpt-4o",
@@ -49,16 +55,25 @@ class ModelFamily:
         "gpt-4",
         "gpt-35",
         "r1",
+        # google_models
         "gemini-1.5-flash",
         "gemini-1.5-pro",
         "gemini-2.0-flash",
         "gemini-2.5-pro",
+        # anthropic_models
         "claude-3-haiku",
         "claude-3-sonnet",
         "claude-3-opus",
         "claude-3-5-haiku",
         "claude-3-5-sonnet",
         "claude-3-7-sonnet",
+        # mistral_models
+        "codestral",
+        "open-codestral-mamba",
+        "mistral",
+        "ministral",
+        "pixtral",
+        # unknown
         "unknown",
     ]
 
@@ -98,6 +113,16 @@ class ModelFamily:
             ModelFamily.GPT_35,
         )
 
+    @staticmethod
+    def is_mistral(family: str) -> bool:
+        return family in (
+            ModelFamily.CODESRAL,
+            ModelFamily.OPEN_CODESRAL_MAMBA,
+            ModelFamily.MISTRAL,
+            ModelFamily.MINISTRAL,
+            ModelFamily.PIXTRAL,
+        )
+
 
 @deprecated("Use the ModelInfo class instead ModelCapabilities.")
 class ModelCapabilities(TypedDict, total=False):
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_message_transform.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_message_transform.py
index 04bd93718..f5ddc08dd 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_message_transform.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_message_transform.py
@@ -275,7 +275,6 @@ base_system_message_transformers: List[Callable[[LLMMessage, Dict[str, Any]], Di
 
 base_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = [
     _assert_valid_name,
-    _set_name,
     _set_role("user"),
 ]
 
@@ -293,6 +292,7 @@ system_message_transformers: List[Callable[[LLMMessage, Dict[str, Any]], Dict[st
 single_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
     base_user_transformer_funcs
     + [
+        _set_name,
         _set_prepend_text_content,
     ]
 )
@@ -300,6 +300,7 @@ single_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[
 multimodal_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
     base_user_transformer_funcs
     + [
+        _set_name,
         _set_multimodal_content,
     ]
 )
@@ -334,6 +335,19 @@ thought_assistant_transformer_funcs_gemini: List[Callable[[LLMMessage, Dict[str,
 
 # === Specific message param functions ===
 
+single_user_transformer_funcs_mistral: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
+    base_user_transformer_funcs
+    + [
+        _set_prepend_text_content,
+    ]
+)
+
+multimodal_user_transformer_funcs_mistral: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
+    base_user_transformer_funcs
+    + [
+        _set_multimodal_content,
+    ]
+)
 
 
 # === Transformer maps ===
@@ -359,6 +373,8 @@ assistant_transformer_funcs: Dict[str, List[Callable[[LLMMessage, Dict[str, Any]
     "tools": tools_assistant_transformer_funcs,
     "thought": thought_assistant_transformer_funcs,
 }
+
+
 assistant_transformer_constructors: Dict[str, Callable[..., Any]] = {
     "text": ChatCompletionAssistantMessageParam,
     "tools": ChatCompletionAssistantMessageParam,
@@ -403,6 +419,12 @@ assistant_transformer_funcs_claude: Dict[str, List[Callable[[LLMMessage, Dict[st
 }
 
 
+user_transformer_funcs_mistral: Dict[str, List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]]] = {
+    "text": single_user_transformer_funcs_mistral,
+    "multimodal": multimodal_user_transformer_funcs_mistral,
+}
+
+
 def function_execution_result_message(message: LLMMessage, context: Dict[str, Any]) -> TrasformerReturnType:
     assert isinstance(message, FunctionExecutionResultMessage)
     return [
@@ -466,6 +488,24 @@ __CLAUDE_TRANSFORMER_MAP: TransformerMap = {
     FunctionExecutionResultMessage: function_execution_result_message,
 }
 
+__MISTRAL_TRANSFORMER_MAP: TransformerMap = {
+    SystemMessage: build_transformer_func(
+        funcs=system_message_transformers + [_set_empty_to_whitespace],
+        message_param_func=ChatCompletionSystemMessageParam,
+    ),
+    UserMessage: build_conditional_transformer_func(
+        funcs_map=user_transformer_funcs_mistral,
+        message_param_func_map=user_transformer_constructors,
+        condition_func=user_condition,
+    ),
+    AssistantMessage: build_conditional_transformer_func(
+        funcs_map=assistant_transformer_funcs,
+        message_param_func_map=assistant_transformer_constructors,
+        condition_func=assistant_condition,
+    ),
+    FunctionExecutionResultMessage: function_execution_result_message,
+}
+
 # set openai models to use the transformer map
 total_models = get_args(ModelFamily.ANY)
 
@@ -475,7 +515,11 @@ __claude_models = [model for model in total_models if ModelFamily.is_claude(mode
 
 __gemini_models = [model for model in total_models if ModelFamily.is_gemini(model)]
 
-__unknown_models = list(set(total_models) - set(__openai_models) - set(__claude_models) - set(__gemini_models))
+__mistral_models = [model for model in total_models if ModelFamily.is_mistral(model)]
+
+__unknown_models = list(
+    set(total_models) - set(__openai_models) - set(__claude_models) - set(__gemini_models) - set(__mistral_models)
+)
 
 for model in __openai_models:
     register_transformer("openai", model, __BASE_TRANSFORMER_MAP)
@@ -486,6 +530,9 @@ for model in __claude_models:
 for model in __gemini_models:
     register_transformer("openai", model, __GEMINI_TRANSFORMER_MAP)
 
+for model in __mistral_models:
+    register_transformer("openai", model, __MISTRAL_TRANSFORMER_MAP)
+
 for model in __unknown_models:
     register_transformer("openai", model, __BASE_TRANSFORMER_MAP)
 
diff --git a/python/packages/autogen-ext/tests/models/test_openai_model_client.py b/python/packages/autogen-ext/tests/models/test_openai_model_client.py
index a54cffb77..1f824274e 100644
--- a/python/packages/autogen-ext/tests/models/test_openai_model_client.py
+++ b/python/packages/autogen-ext/tests/models/test_openai_model_client.py
@@ -2485,4 +2485,17 @@ async def test_multimodal_message_test(
     _ = await ocr_agent.run(task=multi_modal_message)
 
 
+@pytest.mark.asyncio
+async def test_mistral_remove_name() -> None:
+    # Test that the name parameter is removed from the message
+    # when the model is Mistral
+    message = UserMessage(content="foo", source="user")
+    params = to_oai_type(message, prepend_name=False, model="mistral-7b", model_family=ModelFamily.MISTRAL)
+    assert ("name" in params[0]) is False
+
+    # When the model is gpt-4o, the name parameter is kept
+    params = to_oai_type(message, prepend_name=False, model="gpt-4o", model_family=ModelFamily.GPT_4O)
+    assert ("name" in params[0]) is True
+
+
 # TODO: add integration tests for Azure OpenAI using AAD token.
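
For readers unfamiliar with `_message_transform.py`, the sketch below illustrates the mechanism the diff relies on: each message type is converted to an API dict by running an ordered list of small transform functions, and the Mistral user pipelines simply omit the `_set_name` step, so the `name` field never reaches the Mistral API. This is a minimal, self-contained sketch, not the actual autogen code; `build_pipeline` and the plain-dict messages are illustrative stand-ins that only loosely mirror `build_transformer_func` and the real `LLMMessage` types.

```python
from typing import Any, Callable, Dict, List

# A transform takes (message, context) and contributes some keys to the output dict.
Transform = Callable[[Dict[str, Any], Dict[str, Any]], Dict[str, Any]]


def _set_role(role: str) -> Transform:
    # Returns a transform that stamps a fixed role onto the output.
    def fn(message: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
        return {"role": role}

    return fn


def _set_content(message: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
    return {"content": message["content"]}


def _set_name(message: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
    # Only pipelines that include this step emit the OpenAI-style "name" field.
    return {"name": message["source"]}


def build_pipeline(funcs: List[Transform]) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
    # Hypothetical helper: runs each transform and merges its keys into one dict.
    def run(message: Dict[str, Any]) -> Dict[str, Any]:
        result: Dict[str, Any] = {}
        for f in funcs:
            result.update(f(message, {}))
        return result

    return run


base_funcs: List[Transform] = [_set_role("user"), _set_content]

# The OpenAI pipeline keeps _set_name; the Mistral pipeline leaves it out,
# matching how the diff moves _set_name out of base_user_transformer_funcs.
to_openai_user = build_pipeline(base_funcs + [_set_name])
to_mistral_user = build_pipeline(base_funcs)

msg = {"source": "user", "content": "foo"}
assert "name" in to_openai_user(msg)       # {'role': 'user', 'content': 'foo', 'name': 'user'}
assert "name" not in to_mistral_user(msg)  # {'role': 'user', 'content': 'foo'}
```

The real change has the same shape: `_set_name` moves out of the shared `base_user_transformer_funcs` into the generic user pipelines, while the new `*_mistral` pipelines build from the base list without it and are registered per Mistral model family via `register_transformer`, which is exactly what `test_mistral_remove_name` asserts.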