mirror of https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00

Azure OpenAI improvements and fixes; Improve Graphiti Azure OpenAI config (#620)

* Azure OpenAI improvements and fixes; Improve Graphiti Azure OpenAI config
* format

This commit is contained in:
parent 587f1b9876
commit 9cc2e86071
@@ -14,60 +14,64 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import json
 import logging
-from typing import Any
+from typing import ClassVar
 
 from openai import AsyncAzureOpenAI
 from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
 
-from ..prompts.models import Message
-from .client import LLMClient
-from .config import LLMConfig, ModelSize
+from .config import DEFAULT_MAX_TOKENS, LLMConfig
+from .openai_base_client import BaseOpenAIClient
 
 logger = logging.getLogger(__name__)
 
 
-class AzureOpenAILLMClient(LLMClient):
+class AzureOpenAILLMClient(BaseOpenAIClient):
     """Wrapper class for AsyncAzureOpenAI that implements the LLMClient interface."""
 
-    def __init__(self, azure_client: AsyncAzureOpenAI, config: LLMConfig | None = None):
-        super().__init__(config, cache=False)
-        self.azure_client = azure_client
+    # Class-level constants
+    MAX_RETRIES: ClassVar[int] = 2
 
-    async def _generate_response(
-        self,
-        messages: list[Message],
-        response_model: type[BaseModel] | None = None,
-        max_tokens: int = 1024,
-        model_size: ModelSize = ModelSize.medium,
-    ) -> dict[str, Any]:
-        """Generate response using Azure OpenAI client."""
-        # Convert messages to OpenAI format
-        openai_messages: list[ChatCompletionMessageParam] = []
-        for message in messages:
-            message.content = self._clean_input(message.content)
-            if message.role == 'user':
-                openai_messages.append({'role': 'user', 'content': message.content})
-            elif message.role == 'system':
-                openai_messages.append({'role': 'system', 'content': message.content})
-
-        # Ensure model is a string
-        model_name = self.model if self.model else 'gpt-4o-mini'
-
-        try:
-            response = await self.azure_client.chat.completions.create(
-                model=model_name,
-                messages=openai_messages,
-                temperature=float(self.temperature) if self.temperature is not None else 0.7,
-                max_tokens=max_tokens,
-                response_format={'type': 'json_object'},
-            )
-            result = response.choices[0].message.content or '{}'
-
-            # Parse JSON response
-            return json.loads(result)
-        except Exception as e:
-            logger.error(f'Error in Azure OpenAI LLM response: {e}')
-            raise
+    def __init__(
+        self,
+        azure_client: AsyncAzureOpenAI,
+        config: LLMConfig | None = None,
+        max_tokens: int = DEFAULT_MAX_TOKENS,
+    ):
+        super().__init__(config, cache=False, max_tokens=max_tokens)
+        self.client = azure_client
+
+    async def _create_structured_completion(
+        self,
+        model: str,
+        messages: list[ChatCompletionMessageParam],
+        temperature: float | None,
+        max_tokens: int,
+        response_model: type[BaseModel],
+    ):
+        """Create a structured completion using Azure OpenAI's beta parse API."""
+        return await self.client.beta.chat.completions.parse(
+            model=model,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            response_format=response_model,  # type: ignore
+        )
+
+    async def _create_completion(
+        self,
+        model: str,
+        messages: list[ChatCompletionMessageParam],
+        temperature: float | None,
+        max_tokens: int,
+        response_model: type[BaseModel] | None = None,
+    ):
+        """Create a regular completion with JSON format using Azure OpenAI."""
+        return await self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            response_format={'type': 'json_object'},
+        )
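For orientation, a minimal usage sketch of the refactored Azure client. The constructor and generate_response signatures come from the diff above; the module paths, endpoint, key, and deployment name below are illustrative placeholders rather than values from this commit.

import asyncio

from openai import AsyncAzureOpenAI

from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.prompts.models import Message


async def main() -> None:
    # Placeholder Azure resource, key, and API version.
    azure_client = AsyncAzureOpenAI(
        azure_endpoint='https://example-resource.openai.azure.com/',
        api_key='<azure-openai-key>',
        api_version='2024-10-21',
    )
    # For Azure, the configured model name is the deployment name (hypothetical here).
    llm_client = AzureOpenAILLMClient(
        azure_client=azure_client,
        config=LLMConfig(model='my-gpt-deployment'),
    )
    result = await llm_client.generate_response(
        [
            Message(role='system', content='Extract facts and answer in JSON.'),
            Message(role='user', content='Alice met Bob in Paris in 2023.'),
        ]
    )
    print(result)  # parsed dict from the JSON-mode completion


asyncio.run(main())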
graphiti_core/llm_client/openai_base_client.py (new file, 217 lines)
@@ -0,0 +1,217 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import logging
import typing
from abc import abstractmethod
from typing import Any, ClassVar

import openai
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel

from ..prompts.models import Message
from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
from .errors import RateLimitError, RefusalError

logger = logging.getLogger(__name__)

DEFAULT_MODEL = 'gpt-4.1-mini'
DEFAULT_SMALL_MODEL = 'gpt-4.1-nano'


class BaseOpenAIClient(LLMClient):
    """
    Base client class for OpenAI-compatible APIs (OpenAI and Azure OpenAI).

    This class contains shared logic for both OpenAI and Azure OpenAI clients,
    reducing code duplication while allowing for implementation-specific differences.
    """

    # Class-level constants
    MAX_RETRIES: ClassVar[int] = 2

    def __init__(
        self,
        config: LLMConfig | None = None,
        cache: bool = False,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ):
        if cache:
            raise NotImplementedError('Caching is not implemented for OpenAI-based clients')

        if config is None:
            config = LLMConfig()

        super().__init__(config, cache)
        self.max_tokens = max_tokens

    @abstractmethod
    async def _create_completion(
        self,
        model: str,
        messages: list[ChatCompletionMessageParam],
        temperature: float | None,
        max_tokens: int,
        response_model: type[BaseModel] | None = None,
    ) -> Any:
        """Create a completion using the specific client implementation."""
        pass

    @abstractmethod
    async def _create_structured_completion(
        self,
        model: str,
        messages: list[ChatCompletionMessageParam],
        temperature: float | None,
        max_tokens: int,
        response_model: type[BaseModel],
    ) -> Any:
        """Create a structured completion using the specific client implementation."""
        pass

    def _convert_messages_to_openai_format(
        self, messages: list[Message]
    ) -> list[ChatCompletionMessageParam]:
        """Convert internal Message format to OpenAI ChatCompletionMessageParam format."""
        openai_messages: list[ChatCompletionMessageParam] = []
        for m in messages:
            m.content = self._clean_input(m.content)
            if m.role == 'user':
                openai_messages.append({'role': 'user', 'content': m.content})
            elif m.role == 'system':
                openai_messages.append({'role': 'system', 'content': m.content})
        return openai_messages

    def _get_model_for_size(self, model_size: ModelSize) -> str:
        """Get the appropriate model name based on the requested size."""
        if model_size == ModelSize.small:
            return self.small_model or DEFAULT_SMALL_MODEL
        else:
            return self.model or DEFAULT_MODEL

    def _handle_structured_response(self, response: Any) -> dict[str, Any]:
        """Handle structured response parsing and validation."""
        response_object = response.choices[0].message

        if response_object.parsed:
            return response_object.parsed.model_dump()
        elif response_object.refusal:
            raise RefusalError(response_object.refusal)
        else:
            raise Exception(f'Invalid response from LLM: {response_object.model_dump()}')

    def _handle_json_response(self, response: Any) -> dict[str, Any]:
        """Handle JSON response parsing."""
        result = response.choices[0].message.content or '{}'
        return json.loads(result)

    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, Any]:
        """Generate a response using the appropriate client implementation."""
        openai_messages = self._convert_messages_to_openai_format(messages)
        model = self._get_model_for_size(model_size)

        try:
            if response_model:
                response = await self._create_structured_completion(
                    model=model,
                    messages=openai_messages,
                    temperature=self.temperature,
                    max_tokens=max_tokens or self.max_tokens,
                    response_model=response_model,
                )
                return self._handle_structured_response(response)
            else:
                response = await self._create_completion(
                    model=model,
                    messages=openai_messages,
                    temperature=self.temperature,
                    max_tokens=max_tokens or self.max_tokens,
                )
                return self._handle_json_response(response)

        except openai.LengthFinishReasonError as e:
            raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
        except openai.RateLimitError as e:
            raise RateLimitError from e
        except Exception as e:
            logger.error(f'Error in generating LLM response: {e}')
            raise

    async def generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, typing.Any]:
        """Generate a response with retry logic and error handling."""
        if max_tokens is None:
            max_tokens = self.max_tokens

        retry_count = 0
        last_error = None

        # Add multilingual extraction instructions
        messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES

        while retry_count <= self.MAX_RETRIES:
            try:
                response = await self._generate_response(
                    messages, response_model, max_tokens, model_size
                )
                return response
            except (RateLimitError, RefusalError):
                # These errors should not trigger retries
                raise
            except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
                # Let OpenAI's client handle these retries
                raise
            except Exception as e:
                last_error = e

                # Don't retry if we've hit the max retries
                if retry_count >= self.MAX_RETRIES:
                    logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
                    raise

                retry_count += 1

                # Construct a detailed error message for the LLM
                error_context = (
                    f'The previous response attempt was invalid. '
                    f'Error type: {e.__class__.__name__}. '
                    f'Error details: {str(e)}. '
                    f'Please try again with a valid response, ensuring the output matches '
                    f'the expected format and constraints.'
                )

                error_message = Message(role='user', content=error_context)
                messages.append(error_message)
                logger.warning(
                    f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
                )

        # If we somehow get here, raise the last error
        raise last_error or Exception('Max retries exceeded with no specific error')
@@ -14,50 +14,27 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import logging
 import typing
 from typing import ClassVar
 
-import openai
 from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
 
-from ..prompts.models import Message
-from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
-from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
-from .errors import RateLimitError, RefusalError
-
-logger = logging.getLogger(__name__)
-
-DEFAULT_MODEL = 'gpt-4.1-mini'
-DEFAULT_SMALL_MODEL = 'gpt-4.1-nano'
+from .config import DEFAULT_MAX_TOKENS, LLMConfig
+from .openai_base_client import BaseOpenAIClient
 
 
-class OpenAIClient(LLMClient):
+class OpenAIClient(BaseOpenAIClient):
     """
     OpenAIClient is a client class for interacting with OpenAI's language models.
 
-    This class extends the LLMClient and provides methods to initialize the client,
-    get an embedder, and generate responses from the language model.
-
-    Attributes:
-        client (AsyncOpenAI): The OpenAI client used to interact with the API.
-        model (str): The model name to use for generating responses.
-        temperature (float): The temperature to use for generating responses.
-        max_tokens (int): The maximum number of tokens to generate in a response.
-
-    Methods:
-        __init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
-            Initializes the OpenAIClient with the provided configuration, cache setting, and client.
-
-        _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
-            Generates a response from the language model based on the provided messages.
+    This class extends the BaseOpenAIClient and provides OpenAI-specific implementation
+    for creating completions.
     """
 
     # Class-level constants
     MAX_RETRIES: ClassVar[int] = 2
 
     def __init__(
         self,
         config: LLMConfig | None = None,
@@ -72,120 +49,47 @@ class OpenAIClient(LLMClient):
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
             cache (bool): Whether to use caching for responses. Defaults to False.
             client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
 
         """
-        # removed caching to simplify the `generate_response` override
-        if cache:
-            raise NotImplementedError('Caching is not implemented for OpenAI')
+        super().__init__(config, cache, max_tokens)
 
         if config is None:
             config = LLMConfig()
 
-        super().__init__(config, cache)
-
         if client is None:
             self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
         else:
             self.client = client
 
-        self.max_tokens = max_tokens
-
-    async def _generate_response(
+    async def _create_structured_completion(
         self,
-        messages: list[Message],
-        response_model: type[BaseModel] | None = None,
-        max_tokens: int = DEFAULT_MAX_TOKENS,
-        model_size: ModelSize = ModelSize.medium,
-    ) -> dict[str, typing.Any]:
-        openai_messages: list[ChatCompletionMessageParam] = []
-        for m in messages:
-            m.content = self._clean_input(m.content)
-            if m.role == 'user':
-                openai_messages.append({'role': 'user', 'content': m.content})
-            elif m.role == 'system':
-                openai_messages.append({'role': 'system', 'content': m.content})
-        try:
-            if model_size == ModelSize.small:
-                model = self.small_model or DEFAULT_SMALL_MODEL
-            else:
-                model = self.model or DEFAULT_MODEL
-
-            response = await self.client.beta.chat.completions.parse(
-                model=model,
-                messages=openai_messages,
-                temperature=self.temperature,
-                max_tokens=max_tokens or self.max_tokens,
-                response_format=response_model,  # type: ignore
-            )
-
-            response_object = response.choices[0].message
-
-            if response_object.parsed:
-                return response_object.parsed.model_dump()
-            elif response_object.refusal:
-                raise RefusalError(response_object.refusal)
-            else:
-                raise Exception(f'Invalid response from LLM: {response_object.model_dump()}')
-        except openai.LengthFinishReasonError as e:
-            raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
-        except openai.RateLimitError as e:
-            raise RateLimitError from e
-        except Exception as e:
-            logger.error(f'Error in generating LLM response: {e}')
-            raise
-
-    async def generate_response(
-        self,
-        messages: list[Message],
-        response_model: type[BaseModel] | None = None,
-        max_tokens: int | None = None,
-        model_size: ModelSize = ModelSize.medium,
-    ) -> dict[str, typing.Any]:
-        if max_tokens is None:
-            max_tokens = self.max_tokens
-
-        retry_count = 0
-        last_error = None
-
-        # Add multilingual extraction instructions
-        messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
-
-        while retry_count <= self.MAX_RETRIES:
-            try:
-                response = await self._generate_response(
-                    messages, response_model, max_tokens, model_size
-                )
-                return response
-            except (RateLimitError, RefusalError):
-                # These errors should not trigger retries
-                raise
-            except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
-                # Let OpenAI's client handle these retries
-                raise
-            except Exception as e:
-                last_error = e
-
-                # Don't retry if we've hit the max retries
-                if retry_count >= self.MAX_RETRIES:
-                    logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
-                    raise
-
-                retry_count += 1
-
-                # Construct a detailed error message for the LLM
-                error_context = (
-                    f'The previous response attempt was invalid. '
-                    f'Error type: {e.__class__.__name__}. '
-                    f'Error details: {str(e)}. '
-                    f'Please try again with a valid response, ensuring the output matches '
-                    f'the expected format and constraints.'
-                )
-
-                error_message = Message(role='user', content=error_context)
-                messages.append(error_message)
-                logger.warning(
-                    f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
-                )
-
-        # If we somehow get here, raise the last error
-        raise last_error or Exception('Max retries exceeded with no specific error')
+        model: str,
+        messages: list[ChatCompletionMessageParam],
+        temperature: float | None,
+        max_tokens: int,
+        response_model: type[BaseModel],
+    ):
+        """Create a structured completion using OpenAI's beta parse API."""
+        return await self.client.beta.chat.completions.parse(
+            model=model,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            response_format=response_model,  # type: ignore
+        )
+
+    async def _create_completion(
+        self,
+        model: str,
+        messages: list[ChatCompletionMessageParam],
+        temperature: float | None,
+        max_tokens: int,
+        response_model: type[BaseModel] | None = None,
+    ):
+        """Create a regular completion with JSON format."""
+        return await self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            response_format={'type': 'json_object'},
+        )
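Since both concrete clients now inherit generate_response from the base class, a structured-output call looks the same against OpenAI or Azure OpenAI. A rough sketch follows; the ExtractedFact model, prompt text, and placeholder API key are made up for illustration, and the module path for OpenAIClient is assumed.

import asyncio

from pydantic import BaseModel

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_client import OpenAIClient
from graphiti_core.prompts.models import Message


class ExtractedFact(BaseModel):
    # Hypothetical response model used only for this example.
    subject: str
    relation: str
    target: str


async def main() -> None:
    client = OpenAIClient(config=LLMConfig(api_key='sk-...'))  # placeholder key
    result = await client.generate_response(
        [
            Message(role='system', content='Extract a single fact from the text.'),
            Message(role='user', content='Kendra founded Adaptive Tech in 2019.'),
        ],
        response_model=ExtractedFact,
    )
    # The structured path returns the parsed pydantic model as a plain dict.
    print(result)


asyncio.run(main())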
@@ -78,9 +78,11 @@ The server uses the following environment variables:
 - `MODEL_NAME`: OpenAI model name to use for LLM operations.
 - `SMALL_MODEL_NAME`: OpenAI model name to use for smaller LLM operations.
 - `LLM_TEMPERATURE`: Temperature for LLM responses (0.0-2.0).
-- `AZURE_OPENAI_ENDPOINT`: Optional Azure OpenAI endpoint URL
-- `AZURE_OPENAI_DEPLOYMENT_NAME`: Optional Azure OpenAI deployment name
-- `AZURE_OPENAI_API_VERSION`: Optional Azure OpenAI API version
+- `AZURE_OPENAI_ENDPOINT`: Optional Azure OpenAI LLM endpoint URL
+- `AZURE_OPENAI_DEPLOYMENT_NAME`: Optional Azure OpenAI LLM deployment name
+- `AZURE_OPENAI_API_VERSION`: Optional Azure OpenAI LLM API version
+- `AZURE_OPENAI_EMBEDDING_API_KEY`: Optional Azure OpenAI embedding deployment key (if other than `OPENAI_API_KEY`)
+- `AZURE_OPENAI_EMBEDDING_ENDPOINT`: Optional Azure OpenAI embedding endpoint URL
 - `AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME`: Optional Azure OpenAI embedding deployment name
 - `AZURE_OPENAI_EMBEDDING_API_VERSION`: Optional Azure OpenAI embedding API version
 - `AZURE_OPENAI_USE_MANAGED_IDENTITY`: Optional use Azure Managed Identities for authentication
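To show why the LLM and embedding settings are now split, here is a rough sketch (not from the repository) that reads these variables and builds two separate Azure clients; the variable names come from the list above, everything else is illustrative.

import os

from openai import AsyncAzureOpenAI

# LLM deployment settings.
llm_client = AsyncAzureOpenAI(
    azure_endpoint=os.environ['AZURE_OPENAI_ENDPOINT'],
    api_version=os.environ['AZURE_OPENAI_API_VERSION'],
    api_key=os.environ['OPENAI_API_KEY'],
)

# Embedding deployment settings, falling back to the LLM key when no
# embedding-specific key is provided (mirroring the config change below).
embedding_client = AsyncAzureOpenAI(
    azure_endpoint=os.environ['AZURE_OPENAI_EMBEDDING_ENDPOINT'],
    api_version=os.environ['AZURE_OPENAI_EMBEDDING_API_VERSION'],
    api_key=os.environ.get('AZURE_OPENAI_EMBEDDING_API_KEY') or os.environ['OPENAI_API_KEY'],
)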
@@ -367,7 +367,7 @@ class GraphitiEmbedderConfig(BaseModel):
         model_env = os.environ.get('EMBEDDER_MODEL_NAME', '')
         model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL
 
-        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT', None)
+        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_EMBEDDING_ENDPOINT', None)
         azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION', None)
         azure_openai_deployment_name = os.environ.get(
             'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
@@ -390,7 +390,9 @@ class GraphitiEmbedderConfig(BaseModel):
 
         if not azure_openai_use_managed_identity:
             # api key
-            api_key = os.environ.get('OPENAI_API_KEY', None)
+            api_key = os.environ.get('AZURE_OPENAI_EMBEDDING_API_KEY', None) or os.environ.get(
+                'OPENAI_API_KEY', None
+            )
         else:
             # Managed identity
             api_key = None
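When `AZURE_OPENAI_USE_MANAGED_IDENTITY` is set, the config above deliberately leaves `api_key` as None. A rough sketch of how token-based authentication is typically wired up with the openai SDK and azure-identity (not code from this repository; the credential choice and scope URL are standard Azure OpenAI defaults, everything else is illustrative):

import os

from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from openai import AsyncAzureOpenAI

# Request bearer tokens for the Cognitive Services scope instead of using an API key.
token_provider = get_bearer_token_provider(
    DefaultAzureCredential(), 'https://cognitiveservices.azure.com/.default'
)

client = AsyncAzureOpenAI(
    azure_endpoint=os.environ['AZURE_OPENAI_EMBEDDING_ENDPOINT'],
    api_version=os.environ['AZURE_OPENAI_EMBEDDING_API_VERSION'],
    azure_ad_token_provider=token_provider,  # no api_key needed
)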