FIX/mistral could not receive name field (#6503)

## Why are these changes needed?
Mistral models could not receive the `name` field, so this PR adds a model
transformer for Mistral. A minimal sketch of the failure mode is shown below.
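
For context, a sketch of the problem (the dict shapes are illustrative, not captured from the Mistral API): OpenAI-style user message params can carry a `name` key, which Mistral's OpenAI-compatible endpoint rejects, so the Mistral transformer pipeline must emit the message without it.

```python
# Illustrative sketch only. An OpenAI-style user message param carries a
# "name" key that gpt-4o accepts but Mistral rejects; the new Mistral
# transformer pipeline simply never sets that key.
openai_style_param = {
    "role": "user",
    "content": "foo",
    "name": "user",  # accepted by OpenAI models, rejected by Mistral
}

# Net effect of the Mistral pipeline: the same message without "name".
mistral_param = {k: v for k, v in openai_style_param.items() if k != "name"}
assert "name" not in mistral_param
```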

## Related issue number
Closes #6147

## Checks

- [ ] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [x] I've made sure all auto checks have passed.

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
3 changed files with 87 additions and 2 deletions


@@ -37,9 +37,15 @@ class ModelFamily:
CLAUDE_3_5_HAIKU = "claude-3-5-haiku"
CLAUDE_3_5_SONNET = "claude-3-5-sonnet"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet"
CODESRAL = "codestral"
OPEN_CODESRAL_MAMBA = "open-codestral-mamba"
MISTRAL = "mistral"
MINISTRAL = "ministral"
PIXTRAL = "pixtral"
UNKNOWN = "unknown"
ANY: TypeAlias = Literal[
# openai_models
"gpt-41",
"gpt-45",
"gpt-4o",
@@ -49,16 +55,25 @@ class ModelFamily:
"gpt-4",
"gpt-35",
"r1",
# google_models
"gemini-1.5-flash",
"gemini-1.5-pro",
"gemini-2.0-flash",
"gemini-2.5-pro",
# anthropic_models
"claude-3-haiku",
"claude-3-sonnet",
"claude-3-opus",
"claude-3-5-haiku",
"claude-3-5-sonnet",
"claude-3-7-sonnet",
# mistral_models
"codestral",
"open-codestral-mamba",
"mistral",
"ministral",
"pixtral",
# unknown
"unknown",
]
@@ -98,6 +113,16 @@ class ModelFamily:
ModelFamily.GPT_35,
)
@staticmethod
def is_mistral(family: str) -> bool:
return family in (
ModelFamily.CODESRAL,
ModelFamily.OPEN_CODESRAL_MAMBA,
ModelFamily.MISTRAL,
ModelFamily.MINISTRAL,
ModelFamily.PIXTRAL,
)
@deprecated("Use the ModelInfo class instead ModelCapabilities.")
class ModelCapabilities(TypedDict, total=False):
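
As a quick sanity check of the `ModelFamily` changes above, a sketch (assuming `autogen_core.models` exports `ModelFamily`, as in the current package layout) of how the new `is_mistral` check partitions the model list, mirroring the registration loop in the transformer module below:

```python
from typing import get_args

from autogen_core.models import ModelFamily

# Collect every family string that the new is_mistral() check accepts.
mistral_models = [m for m in get_args(ModelFamily.ANY) if ModelFamily.is_mistral(m)]
# Expected: ["codestral", "open-codestral-mamba", "mistral", "ministral", "pixtral"]
```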


@@ -275,7 +275,6 @@ base_system_message_transformers: List[Callable[[LLMMessage, Dict[str, Any]], Di
base_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = [
_assert_valid_name,
_set_name,
_set_role("user"),
]
@@ -293,6 +292,7 @@ system_message_transformers: List[Callable[[LLMMessage, Dict[st
single_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
base_user_transformer_funcs
+ [
_set_name,
_set_prepend_text_content,
]
)
@@ -300,6 +300,7 @@ single_user_transformer_funcs: List[Callable[[LLMMessage, Dict[
multimodal_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
base_user_transformer_funcs
+ [
_set_name,
_set_multimodal_content,
]
)
@@ -334,6 +335,19 @@ thought_assistant_transformer_funcs_gemini: List[Callable[[LLMMessage, Dict[str,
# === Specific message param functions ===
single_user_transformer_funcs_mistral: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
base_user_transformer_funcs
+ [
_set_prepend_text_content,
]
)
multimodal_user_transformer_funcs_mistral: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
base_user_transformer_funcs
+ [
_set_multimodal_content,
]
)
# === Transformer maps ===
@@ -359,6 +373,8 @@ assistant_transformer_funcs: Dict[str, List[Callable[[LLMMessage, Dict[str, Any]
"tools": tools_assistant_transformer_funcs,
"thought": thought_assistant_transformer_funcs,
}
assistant_transformer_constructors: Dict[str, Callable[..., Any]] = {
"text": ChatCompletionAssistantMessageParam,
"tools": ChatCompletionAssistantMessageParam,
@@ -403,6 +419,12 @@ assistant_transformer_funcs_claude: Dict[str, List[Callable[[LLMMessage, Dict[st
}
user_transformer_funcs_mistral: Dict[str, List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]]] = {
"text": single_user_transformer_funcs_mistral,
"multimodal": multimodal_user_transformer_funcs_mistral,
}
def function_execution_result_message(message: LLMMessage, context: Dict[str, Any]) -> TrasformerReturnType:
assert isinstance(message, FunctionExecutionResultMessage)
return [
@@ -466,6 +488,24 @@ __CLAUDE_TRANSFORMER_MAP: TransformerMap = {
FunctionExecutionResultMessage: function_execution_result_message,
}
__MISTRAL_TRANSFORMER_MAP: TransformerMap = {
SystemMessage: build_transformer_func(
funcs=system_message_transformers + [_set_empty_to_whitespace],
message_param_func=ChatCompletionSystemMessageParam,
),
UserMessage: build_conditional_transformer_func(
funcs_map=user_transformer_funcs_mistral,
message_param_func_map=user_transformer_constructors,
condition_func=user_condition,
),
AssistantMessage: build_conditional_transformer_func(
funcs_map=assistant_transformer_funcs,
message_param_func_map=assistant_transformer_constructors,
condition_func=assistant_condition,
),
FunctionExecutionResultMessage: function_execution_result_message,
}
# set openai models to use the transformer map
total_models = get_args(ModelFamily.ANY)
@@ -475,7 +515,11 @@ __claude_models = [model for model in total_models if ModelFamily.is_claude(mode
__gemini_models = [model for model in total_models if ModelFamily.is_gemini(model)]
__unknown_models = list(set(total_models) - set(__openai_models) - set(__claude_models) - set(__gemini_models))
__mistral_models = [model for model in total_models if ModelFamily.is_mistral(model)]
__unknown_models = list(
set(total_models) - set(__openai_models) - set(__claude_models) - set(__gemini_models) - set(__mistral_models)
)
for model in __openai_models:
register_transformer("openai", model, __BASE_TRANSFORMER_MAP)
@@ -486,6 +530,9 @@ for model in __claude_models:
for model in __gemini_models:
register_transformer("openai", model, __GEMINI_TRANSFORMER_MAP)
for model in __mistral_models:
register_transformer("openai", model, __MISTRAL_TRANSFORMER_MAP)
for model in __unknown_models:
register_transformer("openai", model, __BASE_TRANSFORMER_MAP)


@@ -2485,4 +2485,17 @@ async def test_multimodal_message_test(
_ = await ocr_agent.run(task=multi_modal_message)
@pytest.mark.asyncio
async def test_mistral_remove_name() -> None:
# Test that the name parameter is removed from the message
# when the model is Mistral
message = UserMessage(content="foo", source="user")
params = to_oai_type(message, prepend_name=False, model="mistral-7b", model_family=ModelFamily.MISTRAL)
assert ("name" in params[0]) is False
# when the model is gpt-4o, the name parameter is not removed
params = to_oai_type(message, prepend_name=False, model="gpt-4o", model_family=ModelFamily.GPT_4O)
assert ("name" in params[0]) is True
# TODO: add integration tests for Azure OpenAI using AAD token.
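
For reference, the shapes the new test asserts against look roughly like this (illustrative dicts assumed from the OpenAI chat-completion message params, not captured from a run):

```python
# Mistral family: the converted user message omits "name".
mistral_params = [{"role": "user", "content": "foo"}]
# GPT-4o: the OpenAI pipeline keeps it.
gpt4o_params = [{"role": "user", "content": "foo", "name": "user"}]

assert "name" not in mistral_params[0]
assert gpt4o_params[0].get("name") == "user"
```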