graphrag/tests/mock_provider.py

# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing mock model provider definitions."""

from collections.abc import AsyncGenerator, Generator
from typing import Any

from pydantic import BaseModel

from graphrag.config.enums import ModelType
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.response.base import (
    BaseModelOutput,
    BaseModelResponse,
    ModelResponse,
)


class MockChatLLM:
    """A mock chat LLM provider."""

    def __init__(
        self,
        responses: list[str | BaseModel] | None = None,
        config: LanguageModelConfig | None = None,
        json: bool = False,
        **kwargs: Any,
    ):
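        # Responses supplied on the config take precedence over the
        # constructor argument.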
        self.responses = config.responses if config and config.responses else responses
        self.response_index = 0
        self.config = config or LanguageModelConfig(
            type=ModelType.MockChat, model="gpt-4o", api_key="mock"
        )

    async def achat(
        self,
        prompt: str,
        history: list | None = None,
        **kwargs,
    ) -> ModelResponse:
        """Return the next response in the list."""
        return self.chat(prompt, history, **kwargs)

    async def achat_stream(
        self,
        prompt: str,
        history: list | None = None,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Yield each configured response in order."""
        if not self.responses:
            return
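        # Serialize pydantic models to JSON; pass plain strings through as-is.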
        for response in self.responses:
            yield (
                response.model_dump_json()
                if isinstance(response, BaseModel)
                else response
            )

    def chat(
        self,
        prompt: str,
        history: list | None = None,
        **kwargs,
    ) -> ModelResponse:
        """Return the next response in the list, cycling when exhausted."""
        if not self.responses:
            return BaseModelResponse(output=BaseModelOutput(content=""))
        response = self.responses[self.response_index % len(self.responses)]
        self.response_index += 1
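        # Preserve the pydantic model (if any) so callers can assert on the
        # structured value as well as the serialized JSON string.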
        parsed_json = response if isinstance(response, BaseModel) else None
        response = (
            response.model_dump_json() if isinstance(response, BaseModel) else response
        )
        return BaseModelResponse(
            output=BaseModelOutput(content=response),
            parsed_response=parsed_json,
        )

    def chat_stream(
        self,
        prompt: str,
        history: list | None = None,
        **kwargs,
    ) -> Generator[str, None, None]:
        """Stream the next response. Not implemented for the synchronous mock."""
        raise NotImplementedError


class MockEmbeddingLLM:
    """A mock embedding LLM provider."""

    def __init__(self, **kwargs: Any):
        self.config = LanguageModelConfig(
            type=ModelType.MockEmbedding, model="text-embedding-ada-002", api_key="mock"
        )

    def embed_batch(self, text_list: list[str], **kwargs: Any) -> list[list[float]]:
        """Generate an embedding for each text in the input batch."""
        if isinstance(text_list, str):
            return [[1.0, 1.0, 1.0]]
        return [[1.0, 1.0, 1.0] for _ in text_list]

    def embed(self, text: str, **kwargs: Any) -> list[float]:
        """Generate an embedding for the input text."""
        return [1.0, 1.0, 1.0]

    async def aembed(self, text: str, **kwargs: Any) -> list[float]:
        """Generate an embedding for the input text."""
        return [1.0, 1.0, 1.0]

    async def aembed_batch(
        self, text_list: list[str], **kwargs: Any
    ) -> list[list[float]]:
        """Generate an embedding for each text in the input batch."""
        if isinstance(text_list, str):
            return [[1.0, 1.0, 1.0]]
        return [[1.0, 1.0, 1.0] for _ in text_list]
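

# A minimal usage sketch (an editorial addition, not part of the upstream
# module) showing how the mocks above behave. The function names are
# hypothetical; the `.output.content` access follows the
# BaseModelResponse/BaseModelOutput fields constructed in `chat` above.
def _example_mock_chat_cycles_responses() -> None:
    llm = MockChatLLM(responses=["first", "second"])
    assert llm.chat("p").output.content == "first"
    assert llm.chat("p").output.content == "second"
    # response_index wraps with modulo, so the cycle restarts.
    assert llm.chat("p").output.content == "first"


def _example_mock_embedding_is_fixed() -> None:
    llm = MockEmbeddingLLM()
    assert llm.embed("any text") == [1.0, 1.0, 1.0]
    assert llm.embed_batch(["a", "b"]) == [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]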