Azure OpenAI improvements and fixes; Improve Graphiti Azure OpenAI config (#620)

* Azure OpenAI improvements and fixes; Improve Graphiti Azure OpenAI config

* format
Daniel Chalef 2025-06-25 11:48:12 -07:00 committed by GitHub
parent 587f1b9876
commit 9cc2e86071
5 changed files with 307 additions and 178 deletions


@@ -14,60 +14,64 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
from typing import Any
from typing import ClassVar
from openai import AsyncAzureOpenAI
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel
from ..prompts.models import Message
from .client import LLMClient
from .config import LLMConfig, ModelSize
from .config import DEFAULT_MAX_TOKENS, LLMConfig
from .openai_base_client import BaseOpenAIClient
logger = logging.getLogger(__name__)
class AzureOpenAILLMClient(LLMClient):
class AzureOpenAILLMClient(BaseOpenAIClient):
"""Wrapper class for AsyncAzureOpenAI that implements the LLMClient interface."""
def __init__(self, azure_client: AsyncAzureOpenAI, config: LLMConfig | None = None):
super().__init__(config, cache=False)
self.azure_client = azure_client
# Class-level constants
MAX_RETRIES: ClassVar[int] = 2
async def _generate_response(
def __init__(
self,
messages: list[Message],
azure_client: AsyncAzureOpenAI,
config: LLMConfig | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
):
super().__init__(config, cache=False, max_tokens=max_tokens)
self.client = azure_client
async def _create_structured_completion(
self,
model: str,
messages: list[ChatCompletionMessageParam],
temperature: float | None,
max_tokens: int,
response_model: type[BaseModel],
):
"""Create a structured completion using Azure OpenAI's beta parse API."""
return await self.client.beta.chat.completions.parse(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
response_format=response_model, # type: ignore
)
async def _create_completion(
self,
model: str,
messages: list[ChatCompletionMessageParam],
temperature: float | None,
max_tokens: int,
response_model: type[BaseModel] | None = None,
max_tokens: int = 1024,
model_size: ModelSize = ModelSize.medium,
) -> dict[str, Any]:
"""Generate response using Azure OpenAI client."""
# Convert messages to OpenAI format
openai_messages: list[ChatCompletionMessageParam] = []
for message in messages:
message.content = self._clean_input(message.content)
if message.role == 'user':
openai_messages.append({'role': 'user', 'content': message.content})
elif message.role == 'system':
openai_messages.append({'role': 'system', 'content': message.content})
# Ensure model is a string
model_name = self.model if self.model else 'gpt-4o-mini'
try:
response = await self.azure_client.chat.completions.create(
model=model_name,
messages=openai_messages,
temperature=float(self.temperature) if self.temperature is not None else 0.7,
max_tokens=max_tokens,
response_format={'type': 'json_object'},
)
result = response.choices[0].message.content or '{}'
# Parse JSON response
return json.loads(result)
except Exception as e:
logger.error(f'Error in Azure OpenAI LLM response: {e}')
raise
):
"""Create a regular completion with JSON format using Azure OpenAI."""
return await self.client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
response_format={'type': 'json_object'},
)
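
For orientation, a minimal usage sketch of the refactored Azure client follows. Module paths are assumed from the Graphiti repository layout, and the key, endpoint, deployment, and API-version values are placeholders rather than values from this commit.

```python
from openai import AsyncAzureOpenAI

from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.config import LLMConfig

# Placeholder credentials and resource names; substitute your own Azure OpenAI values.
azure_client = AsyncAzureOpenAI(
    api_key='<azure-openai-key>',
    api_version='<api-version>',
    azure_endpoint='https://<your-resource>.openai.azure.com/',
)

# The wrapper now inherits retry, JSON, and structured-output handling from
# BaseOpenAIClient; only the underlying transport client differs.
llm_client = AzureOpenAILLMClient(
    azure_client=azure_client,
    config=LLMConfig(model='<your-llm-deployment-name>'),
)
```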


@@ -0,0 +1,217 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
import typing
from abc import abstractmethod
from typing import Any, ClassVar
import openai
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel
from ..prompts.models import Message
from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
from .errors import RateLimitError, RefusalError
logger = logging.getLogger(__name__)
DEFAULT_MODEL = 'gpt-4.1-mini'
DEFAULT_SMALL_MODEL = 'gpt-4.1-nano'
class BaseOpenAIClient(LLMClient):
"""
Base client class for OpenAI-compatible APIs (OpenAI and Azure OpenAI).
This class contains shared logic for both OpenAI and Azure OpenAI clients,
reducing code duplication while allowing for implementation-specific differences.
"""
# Class-level constants
MAX_RETRIES: ClassVar[int] = 2
def __init__(
self,
config: LLMConfig | None = None,
cache: bool = False,
max_tokens: int = DEFAULT_MAX_TOKENS,
):
if cache:
raise NotImplementedError('Caching is not implemented for OpenAI-based clients')
if config is None:
config = LLMConfig()
super().__init__(config, cache)
self.max_tokens = max_tokens
@abstractmethod
async def _create_completion(
self,
model: str,
messages: list[ChatCompletionMessageParam],
temperature: float | None,
max_tokens: int,
response_model: type[BaseModel] | None = None,
) -> Any:
"""Create a completion using the specific client implementation."""
pass
@abstractmethod
async def _create_structured_completion(
self,
model: str,
messages: list[ChatCompletionMessageParam],
temperature: float | None,
max_tokens: int,
response_model: type[BaseModel],
) -> Any:
"""Create a structured completion using the specific client implementation."""
pass
def _convert_messages_to_openai_format(
self, messages: list[Message]
) -> list[ChatCompletionMessageParam]:
"""Convert internal Message format to OpenAI ChatCompletionMessageParam format."""
openai_messages: list[ChatCompletionMessageParam] = []
for m in messages:
m.content = self._clean_input(m.content)
if m.role == 'user':
openai_messages.append({'role': 'user', 'content': m.content})
elif m.role == 'system':
openai_messages.append({'role': 'system', 'content': m.content})
return openai_messages
def _get_model_for_size(self, model_size: ModelSize) -> str:
"""Get the appropriate model name based on the requested size."""
if model_size == ModelSize.small:
return self.small_model or DEFAULT_SMALL_MODEL
else:
return self.model or DEFAULT_MODEL
def _handle_structured_response(self, response: Any) -> dict[str, Any]:
"""Handle structured response parsing and validation."""
response_object = response.choices[0].message
if response_object.parsed:
return response_object.parsed.model_dump()
elif response_object.refusal:
raise RefusalError(response_object.refusal)
else:
raise Exception(f'Invalid response from LLM: {response_object.model_dump()}')
def _handle_json_response(self, response: Any) -> dict[str, Any]:
"""Handle JSON response parsing."""
result = response.choices[0].message.content or '{}'
return json.loads(result)
async def _generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
model_size: ModelSize = ModelSize.medium,
) -> dict[str, Any]:
"""Generate a response using the appropriate client implementation."""
openai_messages = self._convert_messages_to_openai_format(messages)
model = self._get_model_for_size(model_size)
try:
if response_model:
response = await self._create_structured_completion(
model=model,
messages=openai_messages,
temperature=self.temperature,
max_tokens=max_tokens or self.max_tokens,
response_model=response_model,
)
return self._handle_structured_response(response)
else:
response = await self._create_completion(
model=model,
messages=openai_messages,
temperature=self.temperature,
max_tokens=max_tokens or self.max_tokens,
)
return self._handle_json_response(response)
except openai.LengthFinishReasonError as e:
raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
except openai.RateLimitError as e:
raise RateLimitError from e
except Exception as e:
logger.error(f'Error in generating LLM response: {e}')
raise
async def generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
) -> dict[str, typing.Any]:
"""Generate a response with retry logic and error handling."""
if max_tokens is None:
max_tokens = self.max_tokens
retry_count = 0
last_error = None
# Add multilingual extraction instructions
messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
while retry_count <= self.MAX_RETRIES:
try:
response = await self._generate_response(
messages, response_model, max_tokens, model_size
)
return response
except (RateLimitError, RefusalError):
# These errors should not trigger retries
raise
except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
# Let OpenAI's client handle these retries
raise
except Exception as e:
last_error = e
# Don't retry if we've hit the max retries
if retry_count >= self.MAX_RETRIES:
logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
raise
retry_count += 1
# Construct a detailed error message for the LLM
error_context = (
f'The previous response attempt was invalid. '
f'Error type: {e.__class__.__name__}. '
f'Error details: {str(e)}. '
f'Please try again with a valid response, ensuring the output matches '
f'the expected format and constraints.'
)
error_message = Message(role='user', content=error_context)
messages.append(error_message)
logger.warning(
f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
)
# If we somehow get here, raise the last error
raise last_error or Exception('Max retries exceeded with no specific error')
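
The `generate_response` wrapper above retries failed attempts by appending the error back to the conversation as a user message. A stripped-down sketch of that pattern is shown below; the function and parameter names are illustrative rather than part of the diff, and the real client additionally exempts rate-limit, refusal, and transport errors from retries.

```python
import json
from collections.abc import Awaitable, Callable

MAX_RETRIES = 2  # mirrors BaseOpenAIClient.MAX_RETRIES above


async def generate_json_with_feedback(
    call_llm: Callable[[list[dict]], Awaitable[str]],
    messages: list[dict],
) -> dict:
    """Retry loop that feeds each failure back to the model as a user message."""
    last_error: Exception | None = None
    for attempt in range(MAX_RETRIES + 1):
        try:
            raw = await call_llm(messages)
            return json.loads(raw)  # invalid JSON raises and triggers a retry
        except Exception as e:  # illustrative; the real client skips retries for some errors
            last_error = e
            if attempt == MAX_RETRIES:
                raise
            # Tell the model what went wrong so the next attempt can self-correct.
            messages.append(
                {
                    'role': 'user',
                    'content': (
                        f'The previous response attempt was invalid. '
                        f'Error type: {e.__class__.__name__}. Error details: {e}. '
                        f'Please try again with a valid response.'
                    ),
                }
            )
    raise last_error or RuntimeError('Max retries exceeded with no specific error')
```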


@@ -14,50 +14,27 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import typing
from typing import ClassVar
import openai
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel
from ..prompts.models import Message
from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
from .errors import RateLimitError, RefusalError
logger = logging.getLogger(__name__)
DEFAULT_MODEL = 'gpt-4.1-mini'
DEFAULT_SMALL_MODEL = 'gpt-4.1-nano'
from .config import DEFAULT_MAX_TOKENS, LLMConfig
from .openai_base_client import BaseOpenAIClient
class OpenAIClient(LLMClient):
class OpenAIClient(BaseOpenAIClient):
"""
OpenAIClient is a client class for interacting with OpenAI's language models.
This class extends the LLMClient and provides methods to initialize the client,
get an embedder, and generate responses from the language model.
This class extends the BaseOpenAIClient and provides OpenAI-specific implementation
for creating completions.
Attributes:
client (AsyncOpenAI): The OpenAI client used to interact with the API.
model (str): The model name to use for generating responses.
temperature (float): The temperature to use for generating responses.
max_tokens (int): The maximum number of tokens to generate in a response.
Methods:
__init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
Initializes the OpenAIClient with the provided configuration, cache setting, and client.
_generate_response(messages: list[Message]) -> dict[str, typing.Any]:
Generates a response from the language model based on the provided messages.
"""
# Class-level constants
MAX_RETRIES: ClassVar[int] = 2
def __init__(
self,
config: LLMConfig | None = None,
@@ -72,120 +49,47 @@ class OpenAIClient(LLMClient):
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
cache (bool): Whether to use caching for responses. Defaults to False.
client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
"""
# removed caching to simplify the `generate_response` override
if cache:
raise NotImplementedError('Caching is not implemented for OpenAI')
super().__init__(config, cache, max_tokens)
if config is None:
config = LLMConfig()
super().__init__(config, cache)
if client is None:
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
else:
self.client = client
self.max_tokens = max_tokens
async def _generate_response(
async def _create_structured_completion(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
model_size: ModelSize = ModelSize.medium,
) -> dict[str, typing.Any]:
openai_messages: list[ChatCompletionMessageParam] = []
for m in messages:
m.content = self._clean_input(m.content)
if m.role == 'user':
openai_messages.append({'role': 'user', 'content': m.content})
elif m.role == 'system':
openai_messages.append({'role': 'system', 'content': m.content})
try:
if model_size == ModelSize.small:
model = self.small_model or DEFAULT_SMALL_MODEL
else:
model = self.model or DEFAULT_MODEL
model: str,
messages: list[ChatCompletionMessageParam],
temperature: float | None,
max_tokens: int,
response_model: type[BaseModel],
):
"""Create a structured completion using OpenAI's beta parse API."""
return await self.client.beta.chat.completions.parse(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
response_format=response_model, # type: ignore
)
response = await self.client.beta.chat.completions.parse(
model=model,
messages=openai_messages,
temperature=self.temperature,
max_tokens=max_tokens or self.max_tokens,
response_format=response_model, # type: ignore
)
response_object = response.choices[0].message
if response_object.parsed:
return response_object.parsed.model_dump()
elif response_object.refusal:
raise RefusalError(response_object.refusal)
else:
raise Exception(f'Invalid response from LLM: {response_object.model_dump()}')
except openai.LengthFinishReasonError as e:
raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
except openai.RateLimitError as e:
raise RateLimitError from e
except Exception as e:
logger.error(f'Error in generating LLM response: {e}')
raise
async def generate_response(
async def _create_completion(
self,
messages: list[Message],
model: str,
messages: list[ChatCompletionMessageParam],
temperature: float | None,
max_tokens: int,
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
) -> dict[str, typing.Any]:
if max_tokens is None:
max_tokens = self.max_tokens
retry_count = 0
last_error = None
# Add multilingual extraction instructions
messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
while retry_count <= self.MAX_RETRIES:
try:
response = await self._generate_response(
messages, response_model, max_tokens, model_size
)
return response
except (RateLimitError, RefusalError):
# These errors should not trigger retries
raise
except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
# Let OpenAI's client handle these retries
raise
except Exception as e:
last_error = e
# Don't retry if we've hit the max retries
if retry_count >= self.MAX_RETRIES:
logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
raise
retry_count += 1
# Construct a detailed error message for the LLM
error_context = (
f'The previous response attempt was invalid. '
f'Error type: {e.__class__.__name__}. '
f'Error details: {str(e)}. '
f'Please try again with a valid response, ensuring the output matches '
f'the expected format and constraints.'
)
error_message = Message(role='user', content=error_context)
messages.append(error_message)
logger.warning(
f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
)
# If we somehow get here, raise the last error
raise last_error or Exception('Max retries exceeded with no specific error')
):
"""Create a regular completion with JSON format."""
return await self.client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
response_format={'type': 'json_object'},
)
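
Both concrete clients now share the calling convention of `BaseOpenAIClient`: pass a Pydantic `response_model` to route through the beta parse API, or omit it to fall back to a JSON-object completion. A hedged usage sketch follows; the `ExtractedSummary` model and prompt text are illustrative and not part of Graphiti, and module paths are assumed from the repository layout.

```python
import asyncio

from pydantic import BaseModel

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_client import OpenAIClient
from graphiti_core.prompts.models import Message


class ExtractedSummary(BaseModel):
    # Illustrative schema; any Pydantic model can drive the structured path.
    summary: str


async def main() -> None:
    client = OpenAIClient(config=LLMConfig(api_key='<openai-api-key>'))
    result = await client.generate_response(
        messages=[
            Message(role='system', content='Summarize the user input as JSON.'),
            Message(role='user', content='Graphiti builds temporal knowledge graphs.'),
        ],
        response_model=ExtractedSummary,  # routed to _create_structured_completion
    )
    print(result['summary'])


if __name__ == '__main__':
    asyncio.run(main())
```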


@@ -78,9 +78,11 @@ The server uses the following environment variables:
- `MODEL_NAME`: OpenAI model name to use for LLM operations.
- `SMALL_MODEL_NAME`: OpenAI model name to use for smaller LLM operations.
- `LLM_TEMPERATURE`: Temperature for LLM responses (0.0-2.0).
- `AZURE_OPENAI_ENDPOINT`: Optional Azure OpenAI endpoint URL
- `AZURE_OPENAI_DEPLOYMENT_NAME`: Optional Azure OpenAI deployment name
- `AZURE_OPENAI_API_VERSION`: Optional Azure OpenAI API version
- `AZURE_OPENAI_ENDPOINT`: Optional Azure OpenAI LLM endpoint URL
- `AZURE_OPENAI_DEPLOYMENT_NAME`: Optional Azure OpenAI LLM deployment name
- `AZURE_OPENAI_API_VERSION`: Optional Azure OpenAI LLM API version
- `AZURE_OPENAI_EMBEDDING_API_KEY`: Optional Azure OpenAI Embedding deployment key (if other than `OPENAI_API_KEY`)
- `AZURE_OPENAI_EMBEDDING_ENDPOINT`: Optional Azure OpenAI Embedding endpoint URL
- `AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME`: Optional Azure OpenAI embedding deployment name
- `AZURE_OPENAI_EMBEDDING_API_VERSION`: Optional Azure OpenAI API version
- `AZURE_OPENAI_USE_MANAGED_IDENTITY`: Optional use Azure Managed Identities for authentication
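
For illustration, a hypothetical `.env` fragment wiring up separate Azure OpenAI LLM and embedding deployments with these variables; all keys, endpoints, deployment names, and API versions below are placeholders.

```bash
# LLM deployment (placeholder values)
OPENAI_API_KEY=<azure-openai-llm-key>
MODEL_NAME=gpt-4.1-mini
AZURE_OPENAI_ENDPOINT=https://my-llm-resource.openai.azure.com/
AZURE_OPENAI_DEPLOYMENT_NAME=my-llm-deployment
AZURE_OPENAI_API_VERSION=<api-version>

# Embedding deployment (placeholder values); the key falls back to OPENAI_API_KEY if unset
AZURE_OPENAI_EMBEDDING_API_KEY=<azure-openai-embedding-key>
AZURE_OPENAI_EMBEDDING_ENDPOINT=https://my-embedding-resource.openai.azure.com/
AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME=my-embedding-deployment
AZURE_OPENAI_EMBEDDING_API_VERSION=<api-version>
```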


@@ -367,7 +367,7 @@ class GraphitiEmbedderConfig(BaseModel):
model_env = os.environ.get('EMBEDDER_MODEL_NAME', '')
model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL
azure_openai_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT', None)
azure_openai_endpoint = os.environ.get('AZURE_OPENAI_EMBEDDING_ENDPOINT', None)
azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION', None)
azure_openai_deployment_name = os.environ.get(
'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
@@ -390,7 +390,9 @@ class GraphitiEmbedderConfig(BaseModel):
if not azure_openai_use_managed_identity:
# api key
api_key = os.environ.get('OPENAI_API_KEY', None)
api_key = os.environ.get('AZURE_OPENAI_EMBEDDING_API_KEY', None) or os.environ.get(
'OPENAI_API_KEY', None
)
else:
# Managed identity
api_key = None
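
The embedder now prefers `AZURE_OPENAI_EMBEDDING_API_KEY` and falls back to the shared `OPENAI_API_KEY`, unless managed identity is in use. A minimal sketch of that precedence, using a hypothetical helper name:

```python
import os


def resolve_embedding_api_key(use_managed_identity: bool) -> str | None:
    """Illustrative helper mirroring the precedence in GraphitiEmbedderConfig."""
    if use_managed_identity:
        # Managed identity: no API key; tokens come from Azure AD instead.
        return None
    # The embedding-specific key wins; otherwise fall back to OPENAI_API_KEY.
    return os.environ.get('AZURE_OPENAI_EMBEDDING_API_KEY') or os.environ.get('OPENAI_API_KEY')
```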