graphiti/graphiti_core/embedder/azure_openai.py
Daniel Chalef 3d7e1a4b79
feat: add Azure client wrappers for embedding and LLM, integrate into server (#581)
* create wrappers for azure clients

* remove unused crossencoder client

* format

* chore: update graphiti-core to 0.12.0rc5 and pydantic to 2.11.5

* Update graphiti_core/llm_client/azure_openai_client.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
2025-06-13 11:55:08 -04:00

65 lines
2.3 KiB
Python

"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from typing import Any
from openai import AsyncAzureOpenAI
from .client import EmbedderClient
logger = logging.getLogger(__name__)


class AzureOpenAIEmbedderClient(EmbedderClient):
    """Wrapper class for AsyncAzureOpenAI that implements the EmbedderClient interface.

    Every request is forwarded to ``azure_client.embeddings.create`` with the
    ``model`` configured at construction time. Errors raised by the client are
    logged and re-raised unchanged.
    """

    def __init__(self, azure_client: AsyncAzureOpenAI, model: str = 'text-embedding-3-small'):
        """Store the Azure client and the model name used for all requests.

        Args:
            azure_client: An already-configured async Azure OpenAI client.
            model: Value passed as ``model`` to ``embeddings.create``.
        """
        self.azure_client = azure_client
        self.model = model

    @staticmethod
    def _is_token_ids(value: Any) -> bool:
        """Return True if ``value`` is a non-empty list/tuple made only of ints."""
        return (
            isinstance(value, (list, tuple))
            and len(value) > 0
            # bool is a subclass of int; a list of booleans is not token ids.
            and all(isinstance(tok, int) and not isinstance(tok, bool) for tok in value)
        )

    @classmethod
    def _prepare_input(cls, input_data: Any) -> Any:
        """Normalize ``input_data`` into a value accepted by the embeddings endpoint.

        - ``str`` -> one-element list of that string.
        - list of ``str`` -> passed through unchanged.
        - token ids (a list/tuple of ints, or a list/tuple of such sequences)
          -> passed through as lists, since the endpoint accepts token arrays.
        - anything else -> one-element list holding ``str(input_data)``.
        """
        if isinstance(input_data, str):
            return [input_data]
        if isinstance(input_data, list) and all(isinstance(item, str) for item in input_data):
            return input_data
        # Stringifying token ids would embed their Python repr (e.g. "[1, 2, 3]")
        # instead of the tokens themselves, so forward them as-is.
        if cls._is_token_ids(input_data):
            return list(input_data)
        if (
            isinstance(input_data, (list, tuple))
            and len(input_data) > 0
            and all(cls._is_token_ids(seq) for seq in input_data)
        ):
            return [list(seq) for seq in input_data]
        # Fallback for any other type: embed its string representation.
        return [str(input_data)]

    async def create(self, input_data: str | list[str] | Any) -> list[float]:
        """Create embeddings using Azure OpenAI client.

        Args:
            input_data: Text, a list of texts, token ids, or any other object
                (which is embedded via ``str()``).

        Returns:
            The embedding of the first item in the response.

        Raises:
            Any exception from the client call or from reading the response is
            logged and re-raised unchanged.
        """
        text_input = self._prepare_input(input_data)
        try:
            response = await self.azure_client.embeddings.create(model=self.model, input=text_input)
            # Only the first embedding is returned, matching the list[float] return type.
            return response.data[0].embedding
        except Exception as e:
            logger.error('Error in Azure OpenAI embedding: %s', e)
            raise

    async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
        """Create batch embeddings using Azure OpenAI client.

        Args:
            input_data_list: Texts to embed in a single request.

        Returns:
            One embedding per item in ``response.data``, in response order.
            An empty input yields an empty list without calling the API.

        Raises:
            Any exception from the client call is logged and re-raised unchanged.
        """
        if not input_data_list:
            # The endpoint rejects an empty ``input``; there is nothing to embed anyway.
            return []
        try:
            response = await self.azure_client.embeddings.create(
                model=self.model, input=input_data_list
            )
            return [item.embedding for item in response.data]
        except Exception as e:
            logger.error('Error in Azure OpenAI batch embedding: %s', e)
            raise