Mirror of https://github.com/microsoft/graphrag.git (synced 2025-06-26 23:19:58 +00:00)
Fix/model provider key injection check (#1799)

* Check available models for type validation
* Semver
* Fix ruff and pyright
* Apply feedback
This commit is contained in:
parent e39d869bed
commit 53950f8442
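In short: LanguageModelConfig.type now accepts any string that names a model registered with ModelFactory, and config loading validates the value against that registry instead of a fixed enum. A minimal sketch of the resulting behavior (the custom_chat name and the stub class are illustrative, not part of this PR; a real provider must implement the full ChatModel protocol):

from typing import Any

from graphrag.language_model.factory import ModelFactory


class CustomChatModel:
    """Illustrative stand-in for a user-supplied ChatModel implementation."""

    def __init__(self, **kwargs: Any) -> None: ...


# Register the custom provider under a type string of your choosing.
ModelFactory.register_chat("custom_chat", CustomChatModel)

# These are the new registry checks that config loading now relies on.
assert ModelFactory.is_supported_model("custom_chat")
assert not ModelFactory.is_supported_model("no_such_model")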
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Add check for custom model types during config loading"
+}
@@ -15,6 +15,7 @@ from graphrag.config.errors import (
     AzureDeploymentNameMissingError,
     ConflictingSettingsError,
 )
+from graphrag.language_model.factory import ModelFactory


 class LanguageModelConfig(BaseModel):
@@ -44,7 +45,7 @@ class LanguageModelConfig(BaseModel):
             self.api_key is None or self.api_key.strip() == ""
         ):
             raise ApiKeyMissingError(
-                self.type.value,
+                self.type,
                 self.auth_type.value,
             )

@@ -73,10 +74,24 @@ class LanguageModelConfig(BaseModel):
         if self.auth_type == AuthType.AzureManagedIdentity and (
             self.type == ModelType.OpenAIChat or self.type == ModelType.OpenAIEmbedding
         ):
-            msg = f"auth_type of azure_managed_identity is not supported for model type {self.type.value}. Please rerun `graphrag init` and set the auth_type to api_key."
+            msg = f"auth_type of azure_managed_identity is not supported for model type {self.type}. Please rerun `graphrag init` and set the auth_type to api_key."
             raise ConflictingSettingsError(msg)

-    type: ModelType = Field(description="The type of LLM model to use.")
+    type: ModelType | str = Field(description="The type of LLM model to use.")
+
+    def _validate_type(self) -> None:
+        """Validate the model type.
+
+        Raises
+        ------
+        KeyError
+            If the model name is not recognized.
+        """
+        # Type should be contained by the registered models
+        if not ModelFactory.is_supported_model(self.type):
+            msg = f"Model type {self.type} is not recognized, must be one of {ModelFactory.get_chat_models() + ModelFactory.get_embedding_models()}."
+            raise KeyError(msg)

     model: str = Field(description="The LLM model to use.")
     encoding_model: str = Field(
         description="The encoding model to use",
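With _validate_type in place, a mistyped model type fails fast at config-load time with a KeyError that lists the registered chat and embedding types, rather than surfacing later when the model is instantiated. A hedged sketch of that failure path (the extra fields needed to construct LanguageModelConfig are assumed here and may differ in your setup):

from graphrag.config.models.language_model_config import LanguageModelConfig

try:
    # "openai_chta" is a deliberate typo; model/api_key values are placeholders.
    LanguageModelConfig(type="openai_chta", model="gpt-4o", api_key="placeholder")
except KeyError as err:
    # e.g. "Model type openai_chta is not recognized, must be one of [...]"
    print(err)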
@@ -141,7 +156,7 @@ class LanguageModelConfig(BaseModel):
             self.type == ModelType.AzureOpenAIChat
             or self.type == ModelType.AzureOpenAIEmbedding
         ) and (self.api_base is None or self.api_base.strip() == ""):
-            raise AzureApiBaseMissingError(self.type.value)
+            raise AzureApiBaseMissingError(self.type)

     api_version: str | None = Field(
         description="The version of the LLM API to use.",
@@ -162,7 +177,7 @@ class LanguageModelConfig(BaseModel):
             self.type == ModelType.AzureOpenAIChat
             or self.type == ModelType.AzureOpenAIEmbedding
         ) and (self.api_version is None or self.api_version.strip() == ""):
-            raise AzureApiVersionMissingError(self.type.value)
+            raise AzureApiVersionMissingError(self.type)

     deployment_name: str | None = Field(
         description="The deployment name to use for the LLM service.",
@@ -183,7 +198,7 @@ class LanguageModelConfig(BaseModel):
             self.type == ModelType.AzureOpenAIChat
             or self.type == ModelType.AzureOpenAIEmbedding
         ) and (self.deployment_name is None or self.deployment_name.strip() == ""):
-            raise AzureDeploymentNameMissingError(self.type.value)
+            raise AzureDeploymentNameMissingError(self.type)

     organization: str | None = Field(
         description="The organization to use for the LLM service.",
@@ -251,6 +266,7 @@ class LanguageModelConfig(BaseModel):

     @model_validator(mode="after")
     def _validate_model(self):
+        self._validate_type()
         self._validate_auth_type()
         self._validate_api_key()
         self._validate_azure_settings()
@@ -70,6 +70,33 @@ class ModelFactory:
             raise ValueError(msg)
         return cls._embedding_registry[model_type](**kwargs)

+    @classmethod
+    def get_chat_models(cls) -> list[str]:
+        """Get the registered ChatModel implementations."""
+        return list(cls._chat_registry.keys())
+
+    @classmethod
+    def get_embedding_models(cls) -> list[str]:
+        """Get the registered EmbeddingModel implementations."""
+        return list(cls._embedding_registry.keys())
+
+    @classmethod
+    def is_supported_chat_model(cls, model_type: str) -> bool:
+        """Check if the given model type is supported."""
+        return model_type in cls._chat_registry
+
+    @classmethod
+    def is_supported_embedding_model(cls, model_type: str) -> bool:
+        """Check if the given model type is supported."""
+        return model_type in cls._embedding_registry
+
+    @classmethod
+    def is_supported_model(cls, model_type: str) -> bool:
+        """Check if the given model type is supported."""
+        return cls.is_supported_chat_model(
+            model_type
+        ) or cls.is_supported_embedding_model(model_type)
+

 # --- Register default implementations ---
 ModelFactory.register_chat(
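These new classmethods are what the config validator and its error message lean on: the two get_* helpers expose the registry keys, and is_supported_model is simply the union of the chat and embedding registries. A quick sketch of how they compose (the "openai_chat" string is assumed to be one of the default registrations):

from graphrag.language_model.factory import ModelFactory

# Everything currently registered, including any custom entries added at runtime.
print(ModelFactory.get_chat_models())
print(ModelFactory.get_embedding_models())

# The union check used by LanguageModelConfig._validate_type.
print(ModelFactory.is_supported_model("openai_chat"))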
@@ -8,7 +8,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Any, Protocol

 if TYPE_CHECKING:
-    from collections.abc import AsyncGenerator
+    from collections.abc import AsyncGenerator, Generator

     from graphrag.language_model.response.base import ModelResponse

@@ -143,7 +143,7 @@ class ChatModel(Protocol):

     def chat_stream(
         self, prompt: str, history: list | None = None, **kwargs: Any
-    ) -> AsyncGenerator[str, None]:
+    ) -> Generator[str, None]:
         """
         Generate a response for the given text using a streaming interface.

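Note the signature change: chat_stream is now typed as a plain synchronous Generator rather than an AsyncGenerator (AsyncGenerator stays imported, presumably for the async streaming method, which this diff does not touch). A small sketch of what a conforming implementation and caller look like under the new annotation; the EchoChatModel class is illustrative only:

from __future__ import annotations

from collections.abc import Generator
from typing import Any


class EchoChatModel:
    """Illustrative implementation of the revised chat_stream signature."""

    def chat_stream(
        self, prompt: str, history: list | None = None, **kwargs: Any
    ) -> Generator[str, None]:
        # Yield chunks synchronously; callers consume with a plain for-loop.
        for word in prompt.split():
            yield word


for chunk in EchoChatModel().chat_stream("streaming without await"):
    print(chunk)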
@@ -3,21 +3,16 @@

 """A module containing fnllm model provider definitions."""

-from collections.abc import AsyncGenerator
+from __future__ import annotations
+
+from typing import TYPE_CHECKING

 from fnllm.openai import (
     create_openai_chat_llm,
     create_openai_client,
     create_openai_embeddings_llm,
 )
-from fnllm.openai.types.client import OpenAIChatLLM as FNLLMChatLLM
-from fnllm.openai.types.client import OpenAIEmbeddingsLLM as FNLLMEmbeddingLLM
-
-from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
-from graphrag.config.models.language_model_config import (
-    LanguageModelConfig,
-)
 from graphrag.language_model.providers.fnllm.events import FNLLMEvents
 from graphrag.language_model.providers.fnllm.utils import (
     _create_cache,
@@ -31,6 +26,18 @@ from graphrag.language_model.response.base import (
     ModelResponse,
 )

+if TYPE_CHECKING:
+    from collections.abc import AsyncGenerator, Generator
+
+    from fnllm.openai.types.client import OpenAIChatLLM as FNLLMChatLLM
+    from fnllm.openai.types.client import OpenAIEmbeddingsLLM as FNLLMEmbeddingLLM
+
+    from graphrag.cache.pipeline_cache import PipelineCache
+    from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
+    from graphrag.config.models.language_model_config import (
+        LanguageModelConfig,
+    )
+

 class OpenAIChatFNLLM:
     """An OpenAI Chat Model provider using the fnllm library."""
@@ -121,7 +128,7 @@ class OpenAIChatFNLLM:

     def chat_stream(
         self, prompt: str, history: list | None = None, **kwargs
-    ) -> AsyncGenerator[str, None]:
+    ) -> Generator[str, None]:
         """
         Stream Chat with the Model using the given prompt.

@@ -319,7 +326,7 @@ class AzureOpenAIChatFNLLM:

     def chat_stream(
         self, prompt: str, history: list | None = None, **kwargs
-    ) -> AsyncGenerator[str, None]:
+    ) -> Generator[str, None]:
         """
         Stream Chat with the Model using the given prompt.

@@ -3,24 +3,29 @@

 """A module containing utils for fnllm."""

+from __future__ import annotations
+
 import asyncio
 import threading
-from collections.abc import Coroutine
-from typing import Any, TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar

 from fnllm.base.config import JsonStrategy, RetryStrategy
 from fnllm.openai import AzureOpenAIConfig, OpenAIConfig, PublicOpenAIConfig
 from fnllm.openai.types.chat.parameters import OpenAIChatParameters

 import graphrag.config.defaults as defs
-from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
-from graphrag.config.models.language_model_config import (
-    LanguageModelConfig,
-)
-from graphrag.index.typing.error_handler import ErrorHandlerFn
 from graphrag.language_model.providers.fnllm.cache import FNLLMCacheProvider

+if TYPE_CHECKING:
+    from collections.abc import Coroutine
+
+    from graphrag.cache.pipeline_cache import PipelineCache
+    from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
+    from graphrag.config.models.language_model_config import (
+        LanguageModelConfig,
+    )
+    from graphrag.index.typing.error_handler import ErrorHandlerFn
+

 def _create_cache(cache: PipelineCache | None, name: str) -> FNLLMCacheProvider | None:
     """Create an FNLLM cache from a pipeline cache."""
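The provider and utils modules apply the same housekeeping pattern throughout: runtime imports stay at module level, while names needed only in annotations move under if TYPE_CHECKING, with from __future__ import annotations keeping those annotations unevaluated at runtime (presumably what the "Fix ruff and pyright" bullet in the commit message refers to). A generic sketch of the pattern, using a placeholder function name:

from __future__ import annotations  # annotations are not evaluated at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; avoids runtime import cost and import cycles.
    from collections.abc import Coroutine


def schedule(coro: Coroutine) -> None:
    """The Coroutine annotation resolves for pyright without a runtime import."""
    ...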
@@ -6,7 +6,7 @@
 These tests will test the LLMFactory class and the creation of custom and provided LLMs.
 """

-from collections.abc import AsyncGenerator
+from collections.abc import AsyncGenerator, Generator
 from typing import Any

 from graphrag.language_model.factory import ModelFactory
@@ -40,7 +40,7 @@ async def test_create_custom_chat_model():

         def chat_stream(
             self, prompt: str, history: list | None = None, **kwargs: Any
-        ) -> AsyncGenerator[str, None]: ...
+        ) -> Generator[str, None]: ...

     ModelFactory.register_chat("custom_chat", CustomChatModel)
     model = ModelManager().get_or_create_chat_model("custom", "custom_chat")
@@ -3,7 +3,7 @@

 """A module containing mock model provider definitions."""

-from collections.abc import AsyncGenerator
+from collections.abc import AsyncGenerator, Generator
 from typing import Any

 from pydantic import BaseModel
@@ -85,7 +85,7 @@ class MockChatLLM:
         prompt: str,
         history: list | None = None,
         **kwargs,
-    ) -> AsyncGenerator[str, None]:
+    ) -> Generator[str, None]:
         """Return the next response in the list."""
         raise NotImplementedError
