Fix/transformer aware any modelfamily (#6213)

This PR improves fallback safety when an invalid `model_family` is
supplied to `get_transformer()`. Previously, if a user passed an
arbitrary or incorrect `family` string in `model_info`, the lookup could
fail without falling back to `ModelFamily.UNKNOWN`.

Now, we explicitly check whether `model_family` is a valid value in
`ModelFamily.ANY`. If not, we fall back to `_find_model_family()` as
intended.


## Related issue number

Related to #6011 (see issue comment 2779957730).

## Checks

- [ ] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [x] I've made sure all auto checks have passed.

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
EeS 2025-04-06 11:58:16 +09:00 committed by GitHub
parent faf2a4e6ff
commit b24df29ad0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 8 deletions

View File

@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, get_args
from autogen_core.models import LLMMessage, ModelFamily
def _find_model_family(api: str, model: str) -> str:
    """
    Finds the best matching model family for the given model.

    Search via prefix matching (e.g. "gpt-4o" matches "gpt-4o-1.0").
    Returns ``ModelFamily.UNKNOWN`` when no registered family is a prefix
    of *model*.
    """
    # Track the longest matching prefix so that a more specific family
    # (e.g. "gpt-4o") wins over a shorter one (e.g. "gpt-4").
    len_family = 0
    family = ModelFamily.UNKNOWN
    for _family in MESSAGE_TRANSFORMERS[api].keys():
        if model.startswith(_family):
            if len(_family) > len_family:
                family = _family
                len_family = len(_family)
    return family
def get_transformer(api: str, model: str, model_family: str) -> TransformerMap:
    """
    Returns the transformer map registered for the given API and model family.

    Keeping this as a function (instead of direct dict access) improves long-term flexibility.

    Raises:
        ValueError: if no transformer is registered for the resolved family.
    """
    # An invalid/arbitrary family string (not a valid ModelFamily value) or an
    # explicit UNKNOWN both fall back to prefix-based family resolution.
    if model_family not in set(get_args(ModelFamily.ANY)) or model_family == ModelFamily.UNKNOWN:
        # fallback to finding the best matching model family
        model_family = _find_model_family(api, model)
    transformer = MESSAGE_TRANSFORMERS.get(api, {}).get(model_family, {})
    if not transformer:
        # Just in case, we should never reach here
        raise ValueError(f"No transformer found for model family '{model_family}'")
    return transformer

View File

@ -30,6 +30,7 @@ from autogen_ext.models.openai._openai_client import (
to_oai_type,
)
from autogen_ext.models.openai._transformation import TransformerMap, get_transformer
from autogen_ext.models.openai._transformation.registry import _find_model_family # pyright: ignore[reportPrivateUsage]
from openai.resources.beta.chat.completions import ( # type: ignore
AsyncChatCompletionStreamManager as BetaAsyncChatCompletionStreamManager, # type: ignore
)
@ -2394,11 +2395,6 @@ def test_openai_model_registry_find_well() -> None:
assert get_regitered_transformer(client1) == get_regitered_transformer(client2)
def test_openai_model_registry_find_wrong() -> None:
    """An unregistered/bogus family on an unknown model must raise ValueError."""
    expected_msg = "No transformer found for model family"
    with pytest.raises(ValueError, match=expected_msg):
        get_transformer("openai", "gpt-7", "foobar")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model",
@ -2451,4 +2447,13 @@ def test_rstrip_railing_whitespace_at_last_assistant_content() -> None:
assert result[-1].content == "foobar"
def test_find_model_family() -> None:
    """Prefix matching resolves known families and falls back to UNKNOWN."""
    cases = [
        ("gpt-4", ModelFamily.GPT_4),
        ("gpt-4-latest", ModelFamily.GPT_4),
        ("gpt-4o", ModelFamily.GPT_4O),
        ("gemini-2.0-flash", ModelFamily.GEMINI_2_0_FLASH),
        ("claude-3-5-haiku-20241022", ModelFamily.CLAUDE_3_5_HAIKU),
        ("error", ModelFamily.UNKNOWN),
    ]
    for model, expected in cases:
        assert _find_model_family("openai", model) == expected
# TODO: add integration tests for Azure OpenAI using AAD token.