mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-26 22:48:40 +00:00
Fix/transformer aware any modelfamily (#6213)
This PR improves fallback safety when an invalid `model_family` is supplied to `get_transformer()`. Previously, if a user passed an arbitrary or incorrect `family` string in `model_info`, the lookup could fail without falling back to `ModelFamily.UNKNOWN`. Now, we explicitly check whether `model_family` is a valid value in `ModelFamily.ANY`. If not, we fallback to `_find_model_family()` as intended. ## Related issue number Related #6011#issuecomment-2779957730 ## Checks - [ ] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [x] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [x] I've made sure all auto checks have passed. --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
faf2a4e6ff
commit
b24df29ad0
@ -1,5 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Dict, List
|
||||
from typing import Any, Callable, Dict, List, get_args
|
||||
|
||||
from autogen_core.models import LLMMessage, ModelFamily
|
||||
|
||||
@ -87,10 +87,13 @@ def _find_model_family(api: str, model: str) -> str:
|
||||
Finds the best matching model family for the given model.
|
||||
Search via prefix matching (e.g. "gpt-4o" → "gpt-4o-1.0").
|
||||
"""
|
||||
len_family = 0
|
||||
family = ModelFamily.UNKNOWN
|
||||
for _family in MESSAGE_TRANSFORMERS[api].keys():
|
||||
if model.startswith(_family):
|
||||
family = _family
|
||||
if len(_family) > len_family:
|
||||
family = _family
|
||||
len_family = len(_family)
|
||||
return family
|
||||
|
||||
|
||||
@ -108,13 +111,14 @@ def get_transformer(api: str, model: str, model_family: str) -> TransformerMap:
|
||||
Keeping this as a function (instead of direct dict access) improves long-term flexibility.
|
||||
"""
|
||||
|
||||
if model_family == ModelFamily.UNKNOWN:
|
||||
if model_family not in set(get_args(ModelFamily.ANY)) or model_family == ModelFamily.UNKNOWN:
|
||||
# fallback to finding the best matching model family
|
||||
model_family = _find_model_family(api, model)
|
||||
|
||||
transformer = MESSAGE_TRANSFORMERS.get(api, {}).get(model_family, {})
|
||||
|
||||
if not transformer:
|
||||
# Just in case, we should never reach here
|
||||
raise ValueError(f"No transformer found for model family '{model_family}'")
|
||||
|
||||
return transformer
|
||||
|
||||
@ -30,6 +30,7 @@ from autogen_ext.models.openai._openai_client import (
|
||||
to_oai_type,
|
||||
)
|
||||
from autogen_ext.models.openai._transformation import TransformerMap, get_transformer
|
||||
from autogen_ext.models.openai._transformation.registry import _find_model_family # pyright: ignore[reportPrivateUsage]
|
||||
from openai.resources.beta.chat.completions import ( # type: ignore
|
||||
AsyncChatCompletionStreamManager as BetaAsyncChatCompletionStreamManager, # type: ignore
|
||||
)
|
||||
@ -2394,11 +2395,6 @@ def test_openai_model_registry_find_well() -> None:
|
||||
assert get_regitered_transformer(client1) == get_regitered_transformer(client2)
|
||||
|
||||
|
||||
def test_openai_model_registry_find_wrong() -> None:
|
||||
with pytest.raises(ValueError, match="No transformer found for model family"):
|
||||
get_transformer("openai", "gpt-7", "foobar")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
@ -2451,4 +2447,13 @@ def test_rstrip_railing_whitespace_at_last_assistant_content() -> None:
|
||||
assert result[-1].content == "foobar"
|
||||
|
||||
|
||||
def test_find_model_family() -> None:
|
||||
assert _find_model_family("openai", "gpt-4") == ModelFamily.GPT_4
|
||||
assert _find_model_family("openai", "gpt-4-latest") == ModelFamily.GPT_4
|
||||
assert _find_model_family("openai", "gpt-4o") == ModelFamily.GPT_4O
|
||||
assert _find_model_family("openai", "gemini-2.0-flash") == ModelFamily.GEMINI_2_0_FLASH
|
||||
assert _find_model_family("openai", "claude-3-5-haiku-20241022") == ModelFamily.CLAUDE_3_5_HAIKU
|
||||
assert _find_model_family("openai", "error") == ModelFamily.UNKNOWN
|
||||
|
||||
|
||||
# TODO: add integration tests for Azure OpenAI using AAD token.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user