Mirror of https://github.com/microsoft/autogen.git
FIX/mistral could not receive name field (#6503)
## Why are these changes needed?

Mistral models could not receive the `name` field, so this adds a model transformer for Mistral.

## Related issue number

Closes #6147

## Checks

- [ ] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes introduced in this PR.
- [x] I've made sure all auto checks have passed.

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Parent: 177211b5b4
Commit: 978cbd2e89
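The problem this PR fixes: Mistral's OpenAI-compatible chat endpoint rejects the optional per-message `name` field that AutoGen attaches to user messages, so requests built for OpenAI models failed against Mistral. The fix gives Mistral-family models their own transformer pipeline that never sets `name`. A minimal, self-contained sketch of the idea (illustrative only, not the library's actual code; `MISTRAL_FAMILIES` and `to_user_param` are made up for this sketch):

```python
from typing import Any, Dict

# Families whose chat endpoint rejects the per-message "name" field,
# mirroring the ModelFamily constants added in the diff below.
MISTRAL_FAMILIES = {"codestral", "open-codestral-mamba", "mistral", "ministral", "pixtral"}


def to_user_param(content: str, source: str, model_family: str) -> Dict[str, Any]:
    """Build an OpenAI-style user message dict, omitting "name" for Mistral."""
    param: Dict[str, Any] = {"role": "user", "content": content}
    if model_family not in MISTRAL_FAMILIES:
        param["name"] = source  # OpenAI-style endpoints accept a sender name
    return param


print(to_user_param("foo", "user", "mistral"))  # no "name" key
print(to_user_param("foo", "user", "gpt-4o"))   # includes "name": "user"
```

The diff wires this behavior in three places: new `ModelFamily` constants plus an `is_mistral` classifier, Mistral-specific user-message transformer lists, and a `__MISTRAL_TRANSFORMER_MAP` registered for every Mistral model.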
@@ -37,9 +37,15 @@ class ModelFamily:
     CLAUDE_3_5_HAIKU = "claude-3-5-haiku"
     CLAUDE_3_5_SONNET = "claude-3-5-sonnet"
     CLAUDE_3_7_SONNET = "claude-3-7-sonnet"
+    CODESRAL = "codestral"
+    OPEN_CODESRAL_MAMBA = "open-codestral-mamba"
+    MISTRAL = "mistral"
+    MINISTRAL = "ministral"
+    PIXTRAL = "pixtral"
     UNKNOWN = "unknown"

     ANY: TypeAlias = Literal[
         # openai_models
         "gpt-41",
         "gpt-45",
         "gpt-4o",
@@ -49,16 +55,25 @@ class ModelFamily:
         "gpt-4",
         "gpt-35",
         "r1",
         # google_models
         "gemini-1.5-flash",
         "gemini-1.5-pro",
         "gemini-2.0-flash",
         "gemini-2.5-pro",
         # anthropic_models
         "claude-3-haiku",
         "claude-3-sonnet",
         "claude-3-opus",
         "claude-3-5-haiku",
         "claude-3-5-sonnet",
         "claude-3-7-sonnet",
+        # mistral_models
+        "codestral",
+        "open-codestral-mamba",
+        "mistral",
+        "ministral",
+        "pixtral",
         # unknown
         "unknown",
     ]

@@ -98,6 +113,16 @@ class ModelFamily:
         ModelFamily.GPT_35,
     )

+    @staticmethod
+    def is_mistral(family: str) -> bool:
+        return family in (
+            ModelFamily.CODESRAL,
+            ModelFamily.OPEN_CODESRAL_MAMBA,
+            ModelFamily.MISTRAL,
+            ModelFamily.MINISTRAL,
+            ModelFamily.PIXTRAL,
+        )


 @deprecated("Use the ModelInfo class instead ModelCapabilities.")
 class ModelCapabilities(TypedDict, total=False):
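Once this lands, the new classifier composes with the existing `is_claude` / `is_gemini` helpers. For example (assuming the diff above is applied):

```python
from autogen_core.models import ModelFamily

assert ModelFamily.is_mistral(ModelFamily.PIXTRAL)
assert not ModelFamily.is_mistral(ModelFamily.GPT_4O)
```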
@@ -275,7 +275,6 @@ base_system_message_transformers: List[Callable[[LLMMessage, Dict[str, Any]], Di

 base_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = [
     _assert_valid_name,
-    _set_name,
     _set_role("user"),
 ]

@@ -293,6 +292,7 @@ system_message_transformers: List[Callable[[LLMMessage, Dict[str, Any]], Dict[st
 single_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
     base_user_transformer_funcs
     + [
+        _set_name,
         _set_prepend_text_content,
     ]
 )
@@ -300,6 +300,7 @@ single_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[
 multimodal_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
     base_user_transformer_funcs
     + [
+        _set_name,
         _set_multimodal_content,
     ]
 )
@@ -334,6 +335,19 @@ thought_assistant_transformer_funcs_gemini: List[Callable[[LLMMessage, Dict[str,


 # === Specific message param functions ===
+single_user_transformer_funcs_mistral: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
+    base_user_transformer_funcs
+    + [
+        _set_prepend_text_content,
+    ]
+)
+
+multimodal_user_transformer_funcs_mistral: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
+    base_user_transformer_funcs
+    + [
+        _set_multimodal_content,
+    ]
+)


 # === Transformer maps ===
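The key move in this file: `_set_name` was pulled out of `base_user_transformer_funcs` and re-added only to the OpenAI-facing variants, so the new `*_mistral` lists build user params without ever emitting `name` (while still validating it via `_assert_valid_name`). A rough sketch of how such func lists plausibly compose (illustrative; the real merging lives in `build_transformer_func`, whose internals are assumed here):

```python
from typing import Any, Callable, Dict, List

# Each transformer func maps (message, context) to a partial param dict;
# a pipeline merges the partials into one OpenAI-style message param.
Func = Callable[[Any, Dict[str, Any]], Dict[str, Any]]


def apply_funcs(funcs: List[Func], message: Any, context: Dict[str, Any]) -> Dict[str, Any]:
    param: Dict[str, Any] = {}
    for func in funcs:
        param.update(func(message, context))
    return param
```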
@@ -359,6 +373,8 @@ assistant_transformer_funcs: Dict[str, List[Callable[[LLMMessage, Dict[str, Any]
     "tools": tools_assistant_transformer_funcs,
     "thought": thought_assistant_transformer_funcs,
 }


 assistant_transformer_constructors: Dict[str, Callable[..., Any]] = {
     "text": ChatCompletionAssistantMessageParam,
     "tools": ChatCompletionAssistantMessageParam,
@@ -403,6 +419,12 @@ assistant_transformer_funcs_claude: Dict[str, List[Callable[[LLMMessage, Dict[st
 }


+user_transformer_funcs_mistral: Dict[str, List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]]] = {
+    "text": single_user_transformer_funcs_mistral,
+    "multimodal": multimodal_user_transformer_funcs_mistral,
+}
+
+
 def function_execution_result_message(message: LLMMessage, context: Dict[str, Any]) -> TrasformerReturnType:
     assert isinstance(message, FunctionExecutionResultMessage)
     return [
@@ -466,6 +488,24 @@ __CLAUDE_TRANSFORMER_MAP: TransformerMap = {
     FunctionExecutionResultMessage: function_execution_result_message,
 }

+__MISTRAL_TRANSFORMER_MAP: TransformerMap = {
+    SystemMessage: build_transformer_func(
+        funcs=system_message_transformers + [_set_empty_to_whitespace],
+        message_param_func=ChatCompletionSystemMessageParam,
+    ),
+    UserMessage: build_conditional_transformer_func(
+        funcs_map=user_transformer_funcs_mistral,
+        message_param_func_map=user_transformer_constructors,
+        condition_func=user_condition,
+    ),
+    AssistantMessage: build_conditional_transformer_func(
+        funcs_map=assistant_transformer_funcs,
+        message_param_func_map=assistant_transformer_constructors,
+        condition_func=assistant_condition,
+    ),
+    FunctionExecutionResultMessage: function_execution_result_message,
+}
+

 # set openai models to use the transformer map
 total_models = get_args(ModelFamily.ANY)
@@ -475,7 +515,11 @@ __claude_models = [model for model in total_models if ModelFamily.is_claude(mode

 __gemini_models = [model for model in total_models if ModelFamily.is_gemini(model)]

-__unknown_models = list(set(total_models) - set(__openai_models) - set(__claude_models) - set(__gemini_models))
+__mistral_models = [model for model in total_models if ModelFamily.is_mistral(model)]
+
+__unknown_models = list(
+    set(total_models) - set(__openai_models) - set(__claude_models) - set(__gemini_models) - set(__mistral_models)
+)

 for model in __openai_models:
     register_transformer("openai", model, __BASE_TRANSFORMER_MAP)
@@ -486,6 +530,9 @@ for model in __claude_models:
 for model in __gemini_models:
     register_transformer("openai", model, __GEMINI_TRANSFORMER_MAP)

+for model in __mistral_models:
+    register_transformer("openai", model, __MISTRAL_TRANSFORMER_MAP)
+
 for model in __unknown_models:
     register_transformer("openai", model, __BASE_TRANSFORMER_MAP)

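These loops register a transformer map per `("openai", model)` pair; models not claimed by a family-specific map are explicitly registered with the base map. A toy sketch of that registry pattern (for illustration only; `register` and `lookup` here are hypothetical stand-ins, not autogen's actual `register_transformer` internals):

```python
from typing import Any, Dict, Tuple

# Maps (api, model) to a transformer map keyed by message type.
_REGISTRY: Dict[Tuple[str, str], Dict[type, Any]] = {}


def register(api: str, model: str, transformer_map: Dict[type, Any]) -> None:
    # Later registrations for the same (api, model) pair win.
    _REGISTRY[(api, model)] = transformer_map


def lookup(api: str, model: str, default: Dict[type, Any]) -> Dict[type, Any]:
    # How a consumer might resolve the map, falling back to a base map.
    return _REGISTRY.get((api, model), default)
```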
@@ -2485,4 +2485,17 @@ async def test_multimodal_message_test(
     _ = await ocr_agent.run(task=multi_modal_message)


+@pytest.mark.asyncio
+async def test_mistral_remove_name() -> None:
+    # Test that the name parameter is removed from the message
+    # when the model is Mistral.
+    message = UserMessage(content="foo", source="user")
+    params = to_oai_type(message, prepend_name=False, model="mistral-7b", model_family=ModelFamily.MISTRAL)
+    assert ("name" in params[0]) is False
+
+    # When the model is gpt-4o, the name parameter is not removed.
+    params = to_oai_type(message, prepend_name=False, model="gpt-4o", model_family=ModelFamily.GPT_4O)
+    assert ("name" in params[0]) is True


 # TODO: add integration tests for Azure OpenAI using AAD token.