From 9cc2e86071de9995682ccf79f96df481d62cb5cf Mon Sep 17 00:00:00 2001
From: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
Date: Wed, 25 Jun 2025 11:48:12 -0700
Subject: [PATCH] Azure OpenAI improvements and fixes; Improve Graphiti Azure
 OpenAI config (#620)

* Azure OpenAI improvements and fixes; Improve Graphiti Azure OpenAI config

* format
---
 .../llm_client/azure_openai_client.py      |  88 +++----
 .../llm_client/openai_base_client.py       | 217 ++++++++++++++++++
 graphiti_core/llm_client/openai_client.py  | 166 +++-----------
 mcp_server/README.md                       |   8 +-
 mcp_server/graphiti_mcp_server.py          |   6 +-
 5 files changed, 307 insertions(+), 178 deletions(-)
 create mode 100644 graphiti_core/llm_client/openai_base_client.py

diff --git a/graphiti_core/llm_client/azure_openai_client.py b/graphiti_core/llm_client/azure_openai_client.py
index 60787145..3d4225c1 100644
--- a/graphiti_core/llm_client/azure_openai_client.py
+++ b/graphiti_core/llm_client/azure_openai_client.py
@@ -14,60 +14,64 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import json
 import logging
-from typing import Any
+from typing import ClassVar
 
 from openai import AsyncAzureOpenAI
 from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
 
-from ..prompts.models import Message
-from .client import LLMClient
-from .config import LLMConfig, ModelSize
+from .config import DEFAULT_MAX_TOKENS, LLMConfig
+from .openai_base_client import BaseOpenAIClient
 
 logger = logging.getLogger(__name__)
 
 
-class AzureOpenAILLMClient(LLMClient):
+class AzureOpenAILLMClient(BaseOpenAIClient):
     """Wrapper class for AsyncAzureOpenAI that implements the LLMClient interface."""
 
-    def __init__(self, azure_client: AsyncAzureOpenAI, config: LLMConfig | None = None):
-        super().__init__(config, cache=False)
-        self.azure_client = azure_client
+    # Class-level constants
+    MAX_RETRIES: ClassVar[int] = 2
 
-    async def _generate_response(
+    def __init__(
         self,
-        messages: list[Message],
+        azure_client: AsyncAzureOpenAI,
+        config: LLMConfig | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
+    ):
+        super().__init__(config, cache=False, max_tokens=max_tokens)
+        self.client = azure_client
+
+    async def _create_structured_completion(
+        self,
+        model: str,
+        messages: list[ChatCompletionMessageParam],
+        temperature: float | None,
+        max_tokens: int,
+        response_model: type[BaseModel],
+    ):
+        """Create a structured completion using Azure OpenAI's beta parse API."""
+        return await self.client.beta.chat.completions.parse(
+            model=model,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            response_format=response_model,  # type: ignore
+        )
+
+    async def _create_completion(
+        self,
+        model: str,
+        messages: list[ChatCompletionMessageParam],
+        temperature: float | None,
+        max_tokens: int,
         response_model: type[BaseModel] | None = None,
-        max_tokens: int = 1024,
-        model_size: ModelSize = ModelSize.medium,
-    ) -> dict[str, Any]:
-        """Generate response using Azure OpenAI client."""
-        # Convert messages to OpenAI format
-        openai_messages: list[ChatCompletionMessageParam] = []
-        for message in messages:
-            message.content = self._clean_input(message.content)
-            if message.role == 'user':
-                openai_messages.append({'role': 'user', 'content': message.content})
-            elif message.role == 'system':
-                openai_messages.append({'role': 'system', 'content': message.content})
-
-        # Ensure model is a string
-        model_name = self.model if self.model else 'gpt-4o-mini'
-
-        try:
-            response = await self.azure_client.chat.completions.create(
-                model=model_name,
-                messages=openai_messages,
-                temperature=float(self.temperature) if self.temperature is not None else 0.7,
-                max_tokens=max_tokens,
-                response_format={'type': 'json_object'},
-            )
-            result = response.choices[0].message.content or '{}'
-
-            # Parse JSON response
-            return json.loads(result)
-        except Exception as e:
-            logger.error(f'Error in Azure OpenAI LLM response: {e}')
-            raise
+    ):
+        """Create a regular completion with JSON format using Azure OpenAI."""
+        return await self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            response_format={'type': 'json_object'},
+        )
diff --git a/graphiti_core/llm_client/openai_base_client.py b/graphiti_core/llm_client/openai_base_client.py
new file mode 100644
index 00000000..86bcde10
--- /dev/null
+++ b/graphiti_core/llm_client/openai_base_client.py
@@ -0,0 +1,217 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import json
+import logging
+import typing
+from abc import abstractmethod
+from typing import Any, ClassVar
+
+import openai
+from openai.types.chat import ChatCompletionMessageParam
+from pydantic import BaseModel
+
+from ..prompts.models import Message
+from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
+from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
+from .errors import RateLimitError, RefusalError
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_MODEL = 'gpt-4.1-mini'
+DEFAULT_SMALL_MODEL = 'gpt-4.1-nano'
+
+
+class BaseOpenAIClient(LLMClient):
+    """
+    Base client class for OpenAI-compatible APIs (OpenAI and Azure OpenAI).
+
+    This class contains shared logic for both OpenAI and Azure OpenAI clients,
+    reducing code duplication while allowing for implementation-specific differences.
+    """
+
+    # Class-level constants
+    MAX_RETRIES: ClassVar[int] = 2
+
+    def __init__(
+        self,
+        config: LLMConfig | None = None,
+        cache: bool = False,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
+    ):
+        if cache:
+            raise NotImplementedError('Caching is not implemented for OpenAI-based clients')
+
+        if config is None:
+            config = LLMConfig()
+
+        super().__init__(config, cache)
+        self.max_tokens = max_tokens
+
+    @abstractmethod
+    async def _create_completion(
+        self,
+        model: str,
+        messages: list[ChatCompletionMessageParam],
+        temperature: float | None,
+        max_tokens: int,
+        response_model: type[BaseModel] | None = None,
+    ) -> Any:
+        """Create a completion using the specific client implementation."""
+        pass
+
+    @abstractmethod
+    async def _create_structured_completion(
+        self,
+        model: str,
+        messages: list[ChatCompletionMessageParam],
+        temperature: float | None,
+        max_tokens: int,
+        response_model: type[BaseModel],
+    ) -> Any:
+        """Create a structured completion using the specific client implementation."""
+        pass
+
+    def _convert_messages_to_openai_format(
+        self, messages: list[Message]
+    ) -> list[ChatCompletionMessageParam]:
+        """Convert internal Message format to OpenAI ChatCompletionMessageParam format."""
+        openai_messages: list[ChatCompletionMessageParam] = []
+        for m in messages:
+            m.content = self._clean_input(m.content)
+            if m.role == 'user':
+                openai_messages.append({'role': 'user', 'content': m.content})
+            elif m.role == 'system':
+                openai_messages.append({'role': 'system', 'content': m.content})
+        return openai_messages
+
+    def _get_model_for_size(self, model_size: ModelSize) -> str:
+        """Get the appropriate model name based on the requested size."""
+        if model_size == ModelSize.small:
+            return self.small_model or DEFAULT_SMALL_MODEL
+        else:
+            return self.model or DEFAULT_MODEL
+
+    def _handle_structured_response(self, response: Any) -> dict[str, Any]:
+        """Handle structured response parsing and validation."""
+        response_object = response.choices[0].message
+
+        if response_object.parsed:
+            return response_object.parsed.model_dump()
+        elif response_object.refusal:
+            raise RefusalError(response_object.refusal)
+        else:
+            raise Exception(f'Invalid response from LLM: {response_object.model_dump()}')
+
+    def _handle_json_response(self, response: Any) -> dict[str, Any]:
+        """Handle JSON response parsing."""
+        result = response.choices[0].message.content or '{}'
+        return json.loads(result)
+
+    async def _generate_response(
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
+        model_size: ModelSize = ModelSize.medium,
+    ) -> dict[str, Any]:
+        """Generate a response using the appropriate client implementation."""
+        openai_messages = self._convert_messages_to_openai_format(messages)
+        model = self._get_model_for_size(model_size)
+
+        try:
+            if response_model:
+                response = await self._create_structured_completion(
+                    model=model,
+                    messages=openai_messages,
+                    temperature=self.temperature,
+                    max_tokens=max_tokens or self.max_tokens,
+                    response_model=response_model,
+                )
+                return self._handle_structured_response(response)
+            else:
+                response = await self._create_completion(
+                    model=model,
+                    messages=openai_messages,
+                    temperature=self.temperature,
+                    max_tokens=max_tokens or self.max_tokens,
+                )
+                return self._handle_json_response(response)
+
+        except openai.LengthFinishReasonError as e:
+            raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
+        except openai.RateLimitError as e:
+            raise RateLimitError from e
+        except Exception as e:
+            logger.error(f'Error in generating LLM response: {e}')
+            raise
+
+    async def generate_response(
+        self,
+        messages: list[Message],
+        response_model: type[BaseModel] | None = None,
+        max_tokens: int | None = None,
+        model_size: ModelSize = ModelSize.medium,
+    ) -> dict[str, typing.Any]:
+        """Generate a response with retry logic and error handling."""
+        if max_tokens is None:
+            max_tokens = self.max_tokens
+
+        retry_count = 0
+        last_error = None
+
+        # Add multilingual extraction instructions
+        messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
+
+        while retry_count <= self.MAX_RETRIES:
+            try:
+                response = await self._generate_response(
+                    messages, response_model, max_tokens, model_size
+                )
+                return response
+            except (RateLimitError, RefusalError):
+                # These errors should not trigger retries
+                raise
+            except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
+                # Let OpenAI's client handle these retries
+                raise
+            except Exception as e:
+                last_error = e
+
+                # Don't retry if we've hit the max retries
+                if retry_count >= self.MAX_RETRIES:
+                    logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
+                    raise
+
+                retry_count += 1
+
+                # Construct a detailed error message for the LLM
+                error_context = (
+                    f'The previous response attempt was invalid. '
+                    f'Error type: {e.__class__.__name__}. '
+                    f'Error details: {str(e)}. '
+                    f'Please try again with a valid response, ensuring the output matches '
+                    f'the expected format and constraints.'
+                )
+
+                error_message = Message(role='user', content=error_context)
+                messages.append(error_message)
+                logger.warning(
+                    f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
+                )
+
+        # If we somehow get here, raise the last error
+        raise last_error or Exception('Max retries exceeded with no specific error')
diff --git a/graphiti_core/llm_client/openai_client.py b/graphiti_core/llm_client/openai_client.py
index b627576f..b47f8e9e 100644
--- a/graphiti_core/llm_client/openai_client.py
+++ b/graphiti_core/llm_client/openai_client.py
@@ -14,50 +14,27 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import logging
 import typing
-from typing import ClassVar
 
-import openai
 from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
 
-from ..prompts.models import Message
-from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
-from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
-from .errors import RateLimitError, RefusalError
-
-logger = logging.getLogger(__name__)
-
-DEFAULT_MODEL = 'gpt-4.1-mini'
-DEFAULT_SMALL_MODEL = 'gpt-4.1-nano'
+from .config import DEFAULT_MAX_TOKENS, LLMConfig
+from .openai_base_client import BaseOpenAIClient
 
 
-class OpenAIClient(LLMClient):
+class OpenAIClient(BaseOpenAIClient):
     """
     OpenAIClient is a client class for interacting with OpenAI's language models.
 
-    This class extends the LLMClient and provides methods to initialize the client,
-    get an embedder, and generate responses from the language model.
+    This class extends the BaseOpenAIClient and provides an OpenAI-specific implementation
+    for creating completions.
 
     Attributes:
         client (AsyncOpenAI): The OpenAI client used to interact with the API.
-        model (str): The model name to use for generating responses.
-        temperature (float): The temperature to use for generating responses.
-        max_tokens (int): The maximum number of tokens to generate in a response.
-
-    Methods:
-        __init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
-            Initializes the OpenAIClient with the provided configuration, cache setting, and client.
-
-        _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
-            Generates a response from the language model based on the provided messages.
     """
 
-    # Class-level constants
-    MAX_RETRIES: ClassVar[int] = 2
-
     def __init__(
         self,
         config: LLMConfig | None = None,
@@ -72,120 +49,47 @@ class OpenAIClient(LLMClient):
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
             cache (bool): Whether to use caching for responses. Defaults to False.
             client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
-
         """
-        # removed caching to simplify the `generate_response` override
-        if cache:
-            raise NotImplementedError('Caching is not implemented for OpenAI')
+        super().__init__(config, cache, max_tokens)
 
         if config is None:
             config = LLMConfig()
 
-        super().__init__(config, cache)
-
         if client is None:
             self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
         else:
             self.client = client
 
-        self.max_tokens = max_tokens
-
-    async def _generate_response(
+    async def _create_structured_completion(
         self,
-        messages: list[Message],
-        response_model: type[BaseModel] | None = None,
-        max_tokens: int = DEFAULT_MAX_TOKENS,
-        model_size: ModelSize = ModelSize.medium,
-    ) -> dict[str, typing.Any]:
-        openai_messages: list[ChatCompletionMessageParam] = []
-        for m in messages:
-            m.content = self._clean_input(m.content)
-            if m.role == 'user':
-                openai_messages.append({'role': 'user', 'content': m.content})
-            elif m.role == 'system':
-                openai_messages.append({'role': 'system', 'content': m.content})
-        try:
-            if model_size == ModelSize.small:
-                model = self.small_model or DEFAULT_SMALL_MODEL
-            else:
-                model = self.model or DEFAULT_MODEL
+        model: str,
+        messages: list[ChatCompletionMessageParam],
+        temperature: float | None,
+        max_tokens: int,
+        response_model: type[BaseModel],
+    ):
+        """Create a structured completion using OpenAI's beta parse API."""
+        return await self.client.beta.chat.completions.parse(
+            model=model,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            response_format=response_model,  # type: ignore
+        )
 
-            response = await self.client.beta.chat.completions.parse(
-                model=model,
-                messages=openai_messages,
-                temperature=self.temperature,
-                max_tokens=max_tokens or self.max_tokens,
-                response_format=response_model,  # type: ignore
-            )
-
-            response_object = response.choices[0].message
-
-            if response_object.parsed:
-                return response_object.parsed.model_dump()
-            elif response_object.refusal:
-                raise RefusalError(response_object.refusal)
-            else:
-                raise Exception(f'Invalid response from LLM: {response_object.model_dump()}')
-        except openai.LengthFinishReasonError as e:
-            raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
-        except openai.RateLimitError as e:
-            raise RateLimitError from e
-        except Exception as e:
-            logger.error(f'Error in generating LLM response: {e}')
-            raise
-
-    async def generate_response(
+    async def _create_completion(
         self,
-        messages: list[Message],
+        model: str,
+        messages: list[ChatCompletionMessageParam],
+        temperature: float | None,
+        max_tokens: int,
         response_model: type[BaseModel] | None = None,
-        max_tokens: int | None = None,
-        model_size: ModelSize = ModelSize.medium,
-    ) -> dict[str, typing.Any]:
-        if max_tokens is None:
-            max_tokens = self.max_tokens
-
-        retry_count = 0
-        last_error = None
-
-        # Add multilingual extraction instructions
-        messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
-
-        while retry_count <= self.MAX_RETRIES:
-            try:
-                response = await self._generate_response(
-                    messages, response_model, max_tokens, model_size
-                )
-                return response
-            except (RateLimitError, RefusalError):
-                # These errors should not trigger retries
-                raise
-            except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
-                # Let OpenAI's client handle these retries
-                raise
-            except Exception as e:
-                last_error = e
-
-                # Don't retry if we've hit the max retries
-                if retry_count >= self.MAX_RETRIES:
-                    logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
-                    raise
-
-                retry_count += 1
-
-                # Construct a detailed error message for the LLM
-                error_context = (
-                    f'The previous response attempt was invalid. '
-                    f'Error type: {e.__class__.__name__}. '
-                    f'Error details: {str(e)}. '
-                    f'Please try again with a valid response, ensuring the output matches '
-                    f'the expected format and constraints.'
-                )
-
-                error_message = Message(role='user', content=error_context)
-                messages.append(error_message)
-                logger.warning(
-                    f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
-                )
-
-        # If we somehow get here, raise the last error
-        raise last_error or Exception('Max retries exceeded with no specific error')
+    ):
+        """Create a regular completion with JSON format."""
+        return await self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            response_format={'type': 'json_object'},
+        )
diff --git a/mcp_server/README.md b/mcp_server/README.md
index 4ef4c89f..395d4db2 100644
--- a/mcp_server/README.md
+++ b/mcp_server/README.md
@@ -78,9 +78,11 @@ The server uses the following environment variables:
 - `MODEL_NAME`: OpenAI model name to use for LLM operations.
 - `SMALL_MODEL_NAME`: OpenAI model name to use for smaller LLM operations.
 - `LLM_TEMPERATURE`: Temperature for LLM responses (0.0-2.0).
-- `AZURE_OPENAI_ENDPOINT`: Optional Azure OpenAI endpoint URL
-- `AZURE_OPENAI_DEPLOYMENT_NAME`: Optional Azure OpenAI deployment name
-- `AZURE_OPENAI_API_VERSION`: Optional Azure OpenAI API version
+- `AZURE_OPENAI_ENDPOINT`: Optional Azure OpenAI LLM endpoint URL
+- `AZURE_OPENAI_DEPLOYMENT_NAME`: Optional Azure OpenAI LLM deployment name
+- `AZURE_OPENAI_API_VERSION`: Optional Azure OpenAI LLM API version
+- `AZURE_OPENAI_EMBEDDING_API_KEY`: Optional Azure OpenAI embedding API key (if different from `OPENAI_API_KEY`)
+- `AZURE_OPENAI_EMBEDDING_ENDPOINT`: Optional Azure OpenAI embedding endpoint URL
 - `AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME`: Optional Azure OpenAI embedding deployment name
 - `AZURE_OPENAI_EMBEDDING_API_VERSION`: Optional Azure OpenAI API version
 - `AZURE_OPENAI_USE_MANAGED_IDENTITY`: Optional use Azure Managed Identities for authentication
diff --git a/mcp_server/graphiti_mcp_server.py b/mcp_server/graphiti_mcp_server.py
index f79d3d04..35d5182b 100644
--- a/mcp_server/graphiti_mcp_server.py
+++ b/mcp_server/graphiti_mcp_server.py
@@ -367,7 +367,7 @@ class GraphitiEmbedderConfig(BaseModel):
         model_env = os.environ.get('EMBEDDER_MODEL_NAME', '')
         model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL
 
-        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT', None)
+        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_EMBEDDING_ENDPOINT', None)
         azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION', None)
         azure_openai_deployment_name = os.environ.get(
             'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
@@ -390,7 +390,9 @@ class GraphitiEmbedderConfig(BaseModel):
 
         if not azure_openai_use_managed_identity:
             # api key
-            api_key = os.environ.get('OPENAI_API_KEY', None)
+            api_key = os.environ.get('AZURE_OPENAI_EMBEDDING_API_KEY', None) or os.environ.get(
+                'OPENAI_API_KEY', None
+            )
         else:
             # Managed identity
             api_key = None
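
For orientation, here is a minimal usage sketch of the reworked Azure client (not part of the patch itself). It assumes `LLMConfig` accepts `api_key`, `model`, and `small_model` keyword arguments, as the attribute names used in the clients above suggest; the `AsyncAzureOpenAI` arguments are the standard ones from the `openai` package, and the endpoint, API version, and key values are placeholders.

from openai import AsyncAzureOpenAI

from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.config import LLMConfig

# Placeholder Azure resource details; substitute your own values. With Azure
# OpenAI, the model name is typically the name of your deployment.
azure_client = AsyncAzureOpenAI(
    azure_endpoint='https://your-resource.openai.azure.com',
    api_version='2024-10-21',
    api_key='<azure-openai-api-key>',
)

# The client now takes the async Azure client plus an optional LLMConfig and
# max_tokens, and inherits JSON, structured-output, and retry handling from
# BaseOpenAIClient instead of implementing its own request path.
llm_client = AzureOpenAILLMClient(
    azure_client=azure_client,
    config=LLMConfig(model='gpt-4.1-mini', small_model='gpt-4.1-nano'),
)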
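
A second sketch shows how calls flow through the new `_create_structured_completion` / `_create_completion` split: passing a Pydantic `response_model` routes the request through the beta parse API, while omitting it falls back to a JSON-object completion; both paths go through `BaseOpenAIClient.generate_response`, which appends the multilingual extraction instructions and retries application-level failures up to `MAX_RETRIES`. The `ExtractedSummary` model and prompt text are hypothetical.

from pydantic import BaseModel

from graphiti_core.prompts.models import Message


class ExtractedSummary(BaseModel):
    # Hypothetical response model, used only for illustration.
    summary: str


async def summarize(llm_client) -> dict:
    # With response_model set, _generate_response dispatches to
    # _create_structured_completion; without it, to _create_completion.
    return await llm_client.generate_response(
        messages=[
            Message(role='system', content='You are a concise summarizer.'),
            Message(role='user', content='Summarize: Graphiti builds temporally-aware knowledge graphs.'),
        ],
        response_model=ExtractedSummary,
    )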