Fix/model provider key injection check (#1799)

* Check available models for type validation

* Semver

* Fix ruff and pyright

* Apply feedback
Alonso Guevara 2025-03-11 17:48:30 -06:00 committed by GitHub
parent e39d869bed
commit 53950f8442
8 changed files with 89 additions and 30 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Add check for custom model tyoes while config loading"
}

View File

@ -15,6 +15,7 @@ from graphrag.config.errors import (
AzureDeploymentNameMissingError,
ConflictingSettingsError,
)
from graphrag.language_model.factory import ModelFactory
class LanguageModelConfig(BaseModel):
@ -44,7 +45,7 @@ class LanguageModelConfig(BaseModel):
self.api_key is None or self.api_key.strip() == ""
):
raise ApiKeyMissingError(
self.type.value,
self.type,
self.auth_type.value,
)
@ -73,10 +74,24 @@ class LanguageModelConfig(BaseModel):
if self.auth_type == AuthType.AzureManagedIdentity and (
self.type == ModelType.OpenAIChat or self.type == ModelType.OpenAIEmbedding
):
msg = f"auth_type of azure_managed_identity is not supported for model type {self.type.value}. Please rerun `graphrag init` and set the auth_type to api_key."
msg = f"auth_type of azure_managed_identity is not supported for model type {self.type}. Please rerun `graphrag init` and set the auth_type to api_key."
raise ConflictingSettingsError(msg)
type: ModelType = Field(description="The type of LLM model to use.")
type: ModelType | str = Field(description="The type of LLM model to use.")
def _validate_type(self) -> None:
"""Validate the model type.
Raises
------
KeyError
If the model type is not recognized.
"""
# The type must be one of the registered model types
if not ModelFactory.is_supported_model(self.type):
msg = f"Model type {self.type} is not recognized, must be one of {ModelFactory.get_chat_models() + ModelFactory.get_embedding_models()}."
raise KeyError(msg)
model: str = Field(description="The LLM model to use.")
encoding_model: str = Field(
description="The encoding model to use",
@ -141,7 +156,7 @@ class LanguageModelConfig(BaseModel):
self.type == ModelType.AzureOpenAIChat
or self.type == ModelType.AzureOpenAIEmbedding
) and (self.api_base is None or self.api_base.strip() == ""):
raise AzureApiBaseMissingError(self.type.value)
raise AzureApiBaseMissingError(self.type)
api_version: str | None = Field(
description="The version of the LLM API to use.",
@ -162,7 +177,7 @@ class LanguageModelConfig(BaseModel):
self.type == ModelType.AzureOpenAIChat
or self.type == ModelType.AzureOpenAIEmbedding
) and (self.api_version is None or self.api_version.strip() == ""):
raise AzureApiVersionMissingError(self.type.value)
raise AzureApiVersionMissingError(self.type)
deployment_name: str | None = Field(
description="The deployment name to use for the LLM service.",
@ -183,7 +198,7 @@ class LanguageModelConfig(BaseModel):
self.type == ModelType.AzureOpenAIChat
or self.type == ModelType.AzureOpenAIEmbedding
) and (self.deployment_name is None or self.deployment_name.strip() == ""):
raise AzureDeploymentNameMissingError(self.type.value)
raise AzureDeploymentNameMissingError(self.type)
organization: str | None = Field(
description="The organization to use for the LLM service.",
@ -251,6 +266,7 @@ class LanguageModelConfig(BaseModel):
@model_validator(mode="after")
def _validate_model(self):
self._validate_type()
self._validate_auth_type()
self._validate_api_key()
self._validate_azure_settings()
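
With `_validate_type` now running first inside the `model_validator(mode="after")` hook, an unrecognized model type fails at config-load time instead of surfacing later at model construction. A minimal sketch of the resulting behavior, assuming `ModelType` is importable from `graphrag.config.enums` and that the other constructor arguments shown are valid for a plain OpenAI chat config (both are assumptions, not part of this diff):

```python
from graphrag.config.enums import ModelType
from graphrag.config.models.language_model_config import LanguageModelConfig

# A registered type passes the new check (model/api_key values are illustrative).
ok = LanguageModelConfig(
    type=ModelType.OpenAIChat,
    model="gpt-4o",
    api_key="sk-not-a-real-key",
)

# An arbitrary string type is rejected up front by _validate_type; pydantic lets
# the KeyError propagate out of the after-validator unchanged.
try:
    LanguageModelConfig(
        type="my_unregistered_model",  # hypothetical, not in ModelFactory's registries
        model="whatever",
        api_key="sk-not-a-real-key",
    )
except KeyError as err:
    print(err)  # Model type my_unregistered_model is not recognized, must be one of [...]
```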

View File

@ -70,6 +70,33 @@ class ModelFactory:
raise ValueError(msg)
return cls._embedding_registry[model_type](**kwargs)
@classmethod
def get_chat_models(cls) -> list[str]:
"""Get the registered ChatModel implementations."""
return list(cls._chat_registry.keys())
@classmethod
def get_embedding_models(cls) -> list[str]:
"""Get the registered EmbeddingModel implementations."""
return list(cls._embedding_registry.keys())
@classmethod
def is_supported_chat_model(cls, model_type: str) -> bool:
"""Check if the given model type is supported."""
return model_type in cls._chat_registry
@classmethod
def is_supported_embedding_model(cls, model_type: str) -> bool:
"""Check if the given model type is supported."""
return model_type in cls._embedding_registry
@classmethod
def is_supported_model(cls, model_type: str) -> bool:
"""Check if the given model type is supported."""
return cls.is_supported_chat_model(
model_type
) or cls.is_supported_embedding_model(model_type)
# --- Register default implementations ---
ModelFactory.register_chat(
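
The new classmethods make the factory's registries introspectable, which is what `LanguageModelConfig._validate_type` relies on above. A small usage sketch; the listed type strings are the expected default registrations, so treat the exact values as assumptions:

```python
from graphrag.language_model.factory import ModelFactory

# Everything registered out of the box, chat and embedding side by side.
registered = ModelFactory.get_chat_models() + ModelFactory.get_embedding_models()
print(registered)  # e.g. ['openai_chat', 'azure_openai_chat', ..., 'openai_embedding', ...]

# is_supported_model is the union of the two per-registry checks.
print(ModelFactory.is_supported_model("openai_chat"))            # True, assuming the default string value
print(ModelFactory.is_supported_model("not_a_registered_type"))  # False
```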

View File

@ -8,7 +8,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Generator
from graphrag.language_model.response.base import ModelResponse
@ -143,7 +143,7 @@ class ChatModel(Protocol):
def chat_stream(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> AsyncGenerator[str, None]:
) -> Generator[str, None]:
"""
Generate a response for the given text using a streaming interface.
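
With `chat_stream` retyped as a synchronous `Generator`, callers consume it with a plain `for` loop; the async streaming path is unchanged. A consumption sketch, where `model` is any `ChatModel` implementation obtained elsewhere (e.g. via `ModelManager`) and `achat_stream` is assumed to remain the async counterpart:

```python
def print_stream(model) -> None:
    # chat_stream now yields chunks from a plain generator: no event loop required.
    for chunk in model.chat_stream("Summarize the input documents"):
        print(chunk, end="", flush=True)


async def print_stream_async(model) -> None:
    # The async variant still yields chunks via an AsyncGenerator.
    async for chunk in model.achat_stream("Summarize the input documents"):
        print(chunk, end="", flush=True)
```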

View File

@ -3,21 +3,16 @@
"""A module containing fnllm model provider definitions."""
from collections.abc import AsyncGenerator
from __future__ import annotations
from typing import TYPE_CHECKING
from fnllm.openai import (
create_openai_chat_llm,
create_openai_client,
create_openai_embeddings_llm,
)
from fnllm.openai.types.client import OpenAIChatLLM as FNLLMChatLLM
from fnllm.openai.types.client import OpenAIEmbeddingsLLM as FNLLMEmbeddingLLM
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.language_model_config import (
LanguageModelConfig,
)
from graphrag.language_model.providers.fnllm.events import FNLLMEvents
from graphrag.language_model.providers.fnllm.utils import (
_create_cache,
@ -31,6 +26,18 @@ from graphrag.language_model.response.base import (
ModelResponse,
)
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Generator
from fnllm.openai.types.client import OpenAIChatLLM as FNLLMChatLLM
from fnllm.openai.types.client import OpenAIEmbeddingsLLM as FNLLMEmbeddingLLM
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.language_model_config import (
LanguageModelConfig,
)
class OpenAIChatFNLLM:
"""An OpenAI Chat Model provider using the fnllm library."""
@ -121,7 +128,7 @@ class OpenAIChatFNLLM:
def chat_stream(
self, prompt: str, history: list | None = None, **kwargs
) -> AsyncGenerator[str, None]:
) -> Generator[str, None]:
"""
Stream Chat with the Model using the given prompt.
@ -319,7 +326,7 @@ class AzureOpenAIChatFNLLM:
def chat_stream(
self, prompt: str, history: list | None = None, **kwargs
) -> AsyncGenerator[str, None]:
) -> Generator[str, None]:
"""
Stream Chat with the Model using the given prompt.

View File

@ -3,24 +3,29 @@
"""A module containing utils for fnllm."""
from __future__ import annotations
import asyncio
import threading
from collections.abc import Coroutine
from typing import Any, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar
from fnllm.base.config import JsonStrategy, RetryStrategy
from fnllm.openai import AzureOpenAIConfig, OpenAIConfig, PublicOpenAIConfig
from fnllm.openai.types.chat.parameters import OpenAIChatParameters
import graphrag.config.defaults as defs
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.language_model_config import (
LanguageModelConfig,
)
from graphrag.index.typing.error_handler import ErrorHandlerFn
from graphrag.language_model.providers.fnllm.cache import FNLLMCacheProvider
if TYPE_CHECKING:
from collections.abc import Coroutine
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.language_model_config import (
LanguageModelConfig,
)
from graphrag.index.typing.error_handler import ErrorHandlerFn
def _create_cache(cache: PipelineCache | None, name: str) -> FNLLMCacheProvider | None:
"""Create an FNLLM cache from a pipeline cache."""

View File

@ -6,7 +6,7 @@
These tests will test the LLMFactory class and the creation of custom and provided LLMs.
"""
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Generator
from typing import Any
from graphrag.language_model.factory import ModelFactory
@ -40,7 +40,7 @@ async def test_create_custom_chat_model():
def chat_stream(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> AsyncGenerator[str, None]: ...
) -> Generator[str, None]: ...
ModelFactory.register_chat("custom_chat", CustomChatModel)
model = ModelManager().get_or_create_chat_model("custom", "custom_chat")

View File

@ -3,7 +3,7 @@
"""A module containing mock model provider definitions."""
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Generator
from typing import Any
from pydantic import BaseModel
@ -85,7 +85,7 @@ class MockChatLLM:
prompt: str,
history: list | None = None,
**kwargs,
) -> AsyncGenerator[str, None]:
) -> Generator[str, None]:
"""Return the next response in the list."""
raise NotImplementedError