feat: add Azure client wrappers for embedding and LLM, integrate into server (#581)

* create wrappers for azure clients

* rremove unused crossencoder client

* format

* chore: update graphiti-core to 0.12.0rc5 and pydantic to 2.11.5

* Update graphiti_core/llm_client/azure_openai_client.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
Daniel Chalef 2025-06-13 08:55:08 -07:00 committed by GitHub
parent 5287810d2d
commit 3d7e1a4b79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 207 additions and 51 deletions

View File

@ -106,7 +106,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
if len(top_logprobs) == 0:
continue
norm_logprobs = np.exp(top_logprobs[0].logprob)
if top_logprobs[0].token.strip().split(" ")[0].lower() == "true":
if top_logprobs[0].token.strip().split(' ')[0].lower() == 'true':
scores.append(norm_logprobs)
else:
scores.append(1 - norm_logprobs)

View File

@ -0,0 +1,64 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from typing import Any
from openai import AsyncAzureOpenAI
from .client import EmbedderClient
logger = logging.getLogger(__name__)
class AzureOpenAIEmbedderClient(EmbedderClient):
"""Wrapper class for AsyncAzureOpenAI that implements the EmbedderClient interface."""
def __init__(self, azure_client: AsyncAzureOpenAI, model: str = 'text-embedding-3-small'):
self.azure_client = azure_client
self.model = model
async def create(self, input_data: str | list[str] | Any) -> list[float]:
"""Create embeddings using Azure OpenAI client."""
try:
# Handle different input types
if isinstance(input_data, str):
text_input = [input_data]
elif isinstance(input_data, list) and all(isinstance(item, str) for item in input_data):
text_input = input_data
else:
# Convert to string list for other types
text_input = [str(input_data)]
response = await self.azure_client.embeddings.create(model=self.model, input=text_input)
# Return the first embedding as a list of floats
return response.data[0].embedding
except Exception as e:
logger.error(f'Error in Azure OpenAI embedding: {e}')
raise
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
"""Create batch embeddings using Azure OpenAI client."""
try:
response = await self.azure_client.embeddings.create(
model=self.model, input=input_data_list
)
return [embedding.embedding for embedding in response.data]
except Exception as e:
logger.error(f'Error in Azure OpenAI batch embedding: {e}')
raise

View File

@ -0,0 +1,73 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
from typing import Any
from openai import AsyncAzureOpenAI
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel
from ..prompts.models import Message
from .client import LLMClient
from .config import LLMConfig, ModelSize
logger = logging.getLogger(__name__)
class AzureOpenAILLMClient(LLMClient):
"""Wrapper class for AsyncAzureOpenAI that implements the LLMClient interface."""
def __init__(self, azure_client: AsyncAzureOpenAI, config: LLMConfig | None = None):
super().__init__(config, cache=False)
self.azure_client = azure_client
async def _generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = 1024,
model_size: ModelSize = ModelSize.medium,
) -> dict[str, Any]:
"""Generate response using Azure OpenAI client."""
# Convert messages to OpenAI format
openai_messages: list[ChatCompletionMessageParam] = []
for message in messages:
message.content = self._clean_input(message.content)
if message.role == 'user':
openai_messages.append({'role': 'user', 'content': message.content})
elif message.role == 'system':
openai_messages.append({'role': 'system', 'content': message.content})
# Ensure model is a string
model_name = self.model if self.model else 'gpt-4o-mini'
try:
response = await self.azure_client.chat.completions.create(
model=model_name,
messages=openai_messages,
temperature=float(self.temperature) if self.temperature is not None else 0.7,
max_tokens=max_tokens,
response_format={'type': 'json_object'},
)
result = response.choices[0].message.content or '{}'
# Parse JSON response
return json.loads(result)
except Exception as e:
logger.error(f'Error in Azure OpenAI LLM response: {e}')
raise

View File

@ -19,12 +19,12 @@ from openai import AsyncAzureOpenAI
from pydantic import BaseModel, Field
from graphiti_core import Graphiti
from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.edges import EntityEdge
from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
from graphiti_core.embedder.client import EmbedderClient
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_client import OpenAIClient
from graphiti_core.nodes import EpisodeType, EpisodicNode
@ -37,6 +37,7 @@ from graphiti_core.utils.maintenance.graph_data_operations import clear_data
load_dotenv()
DEFAULT_LLM_MODEL = 'gpt-4.1-mini'
SMALL_LLM_MODEL = 'gpt-4.1-nano'
DEFAULT_EMBEDDER_MODEL = 'text-embedding-3-small'
@ -282,11 +283,11 @@ class GraphitiLLMConfig(BaseModel):
return config
def create_client(self) -> LLMClient | None:
def create_client(self) -> LLMClient:
"""Create an LLM client based on this configuration.
Returns:
LLMClient instance if able, None otherwise
LLMClient instance
"""
if self.azure_openai_endpoint is not None:
@ -294,26 +295,41 @@ class GraphitiLLMConfig(BaseModel):
if self.azure_openai_use_managed_identity:
# Use managed identity for authentication
token_provider = create_azure_credential_token_provider()
return AsyncAzureOpenAI(
azure_endpoint=self.azure_openai_endpoint,
azure_deployment=self.azure_openai_deployment_name,
api_version=self.azure_openai_api_version,
azure_ad_token_provider=token_provider,
return AzureOpenAILLMClient(
azure_client=AsyncAzureOpenAI(
azure_endpoint=self.azure_openai_endpoint,
azure_deployment=self.azure_openai_deployment_name,
api_version=self.azure_openai_api_version,
azure_ad_token_provider=token_provider,
),
config=LLMConfig(
api_key=self.api_key,
model=self.model,
small_model=self.small_model,
temperature=self.temperature,
),
)
elif self.api_key:
# Use API key for authentication
return AsyncAzureOpenAI(
azure_endpoint=self.azure_openai_endpoint,
azure_deployment=self.azure_openai_deployment_name,
api_version=self.azure_openai_api_version,
api_key=self.api_key,
return AzureOpenAILLMClient(
azure_client=AsyncAzureOpenAI(
azure_endpoint=self.azure_openai_endpoint,
azure_deployment=self.azure_openai_deployment_name,
api_version=self.azure_openai_api_version,
api_key=self.api_key,
),
config=LLMConfig(
api_key=self.api_key,
model=self.model,
small_model=self.small_model,
temperature=self.temperature,
),
)
else:
logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API')
return None
raise ValueError('OPENAI_API_KEY must be set when using Azure OpenAI API')
if not self.api_key:
return None
raise ValueError('OPENAI_API_KEY must be set when using OpenAI API')
llm_client_config = LLMConfig(
api_key=self.api_key, model=self.model, small_model=self.small_model
@ -324,17 +340,6 @@ class GraphitiLLMConfig(BaseModel):
return OpenAIClient(config=llm_client_config)
def create_cross_encoder_client(self) -> CrossEncoderClient | None:
"""Create a cross-encoder client based on this configuration."""
if self.azure_openai_endpoint is not None:
client = self.create_client()
return OpenAIRerankerClient(client=client)
else:
llm_client_config = LLMConfig(
api_key=self.api_key, model=self.model, small_model=self.small_model
)
return OpenAIRerankerClient(config=llm_client_config)
class GraphitiEmbedderConfig(BaseModel):
"""Configuration for the embedder client.
@ -404,19 +409,25 @@ class GraphitiEmbedderConfig(BaseModel):
if self.azure_openai_use_managed_identity:
# Use managed identity for authentication
token_provider = create_azure_credential_token_provider()
return AsyncAzureOpenAI(
azure_endpoint=self.azure_openai_endpoint,
azure_deployment=self.azure_openai_deployment_name,
api_version=self.azure_openai_api_version,
azure_ad_token_provider=token_provider,
return AzureOpenAIEmbedderClient(
azure_client=AsyncAzureOpenAI(
azure_endpoint=self.azure_openai_endpoint,
azure_deployment=self.azure_openai_deployment_name,
api_version=self.azure_openai_api_version,
azure_ad_token_provider=token_provider,
),
model=self.model,
)
elif self.api_key:
# Use API key for authentication
return AsyncAzureOpenAI(
azure_endpoint=self.azure_openai_endpoint,
azure_deployment=self.azure_openai_deployment_name,
api_version=self.azure_openai_api_version,
api_key=self.api_key,
return AzureOpenAIEmbedderClient(
azure_client=AsyncAzureOpenAI(
azure_endpoint=self.azure_openai_endpoint,
azure_deployment=self.azure_openai_deployment_name,
api_version=self.azure_openai_api_version,
api_key=self.api_key,
),
model=self.model,
)
else:
logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API')
@ -570,7 +581,6 @@ async def initialize_graphiti():
raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')
embedder_client = config.embedder.create_client()
cross_encoder_client = config.llm.create_cross_encoder_client()
# Initialize Graphiti client
graphiti_client = Graphiti(
@ -579,7 +589,6 @@ async def initialize_graphiti():
password=config.neo4j.password,
llm_client=llm_client,
embedder=embedder_client,
cross_encoder=cross_encoder_client,
)
# Destroy graph if requested

View File

@ -7,6 +7,6 @@ requires-python = ">=3.10"
dependencies = [
"mcp>=1.5.0",
"openai>=1.68.2",
"graphiti-core>=0.8.2",
"graphiti-core>=0.11.6",
"azure-identity>=1.21.0",
]

28
mcp_server/uv.lock generated
View File

@ -282,8 +282,8 @@ wheels = [
[[package]]
name = "graphiti-core"
version = "0.11.6"
source = { registry = "https://pypi.org/simple" }
version = "0.12.0rc5"
source = { directory = "../" }
dependencies = [
{ name = "diskcache" },
{ name = "neo4j" },
@ -293,9 +293,19 @@ dependencies = [
{ name = "python-dotenv" },
{ name = "tenacity" },
]
sdist = { url = "https://files.pythonhosted.org/packages/30/94/3f84400e5f02ea8e9dc79784202de4173cbc16f4b3ad1bd4302da888e4d8/graphiti_core-0.11.6.tar.gz", hash = "sha256:31d26621834d7d4b8865059ab749feb18af15937b59c69598a640a5dfabea331", size = 71928 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ac/2e/c8f22f01585bf173d1c82f6d4615511aebc75aeda764c69aa394446fa93c/graphiti_core-0.11.6-py3-none-any.whl", hash = "sha256:6ec4807a884f5ea88b942d0c8b7bcd2e107c7358ab4f98ef2a2092c229929707", size = 111001 },
[package.metadata]
requires-dist = [
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" },
{ name = "diskcache", specifier = ">=5.6.3" },
{ name = "google-genai", marker = "extra == 'google-genai'", specifier = ">=1.8.0" },
{ name = "groq", marker = "extra == 'groq'", specifier = ">=0.2.0" },
{ name = "neo4j", specifier = ">=5.26.0" },
{ name = "numpy", specifier = ">=1.0.0" },
{ name = "openai", specifier = ">=1.53.0" },
{ name = "pydantic", specifier = ">=2.11.5" },
{ name = "python-dotenv", specifier = ">=1.0.1" },
{ name = "tenacity", specifier = ">=9.0.0" },
]
[[package]]
@ -459,7 +469,7 @@ dependencies = [
[package.metadata]
requires-dist = [
{ name = "azure-identity", specifier = ">=1.21.0" },
{ name = "graphiti-core", specifier = ">=0.8.2" },
{ name = "graphiti-core", directory = "../" },
{ name = "mcp", specifier = ">=1.5.0" },
{ name = "openai", specifier = ">=1.68.2" },
]
@ -594,7 +604,7 @@ wheels = [
[[package]]
name = "pydantic"
version = "2.11.4"
version = "2.11.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "annotated-types" },
@ -602,9 +612,9 @@ dependencies = [
{ name = "typing-extensions" },
{ name = "typing-inspection" },
]
sdist = { url = "https://files.pythonhosted.org/packages/77/ab/5250d56ad03884ab5efd07f734203943c8a8ab40d551e208af81d0257bf2/pydantic-2.11.4.tar.gz", hash = "sha256:32738d19d63a226a52eed76645a98ee07c1f410ee41d93b4afbfa85ed8111c2d", size = 786540 }
sdist = { url = "https://files.pythonhosted.org/packages/f0/86/8ce9040065e8f924d642c58e4a344e33163a07f6b57f836d0d734e0ad3fb/pydantic-2.11.5.tar.gz", hash = "sha256:7f853db3d0ce78ce8bbb148c401c2cdd6431b3473c0cdff2755c7690952a7b7a", size = 787102 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e7/12/46b65f3534d099349e38ef6ec98b1a5a81f42536d17e0ba382c28c67ba67/pydantic-2.11.4-py3-none-any.whl", hash = "sha256:d9615eaa9ac5a063471da949c8fc16376a84afb5024688b3ff885693506764eb", size = 443900 },
{ url = "https://files.pythonhosted.org/packages/b5/69/831ed22b38ff9b4b64b66569f0e5b7b97cf3638346eb95a2147fdb49ad5f/pydantic-2.11.5-py3-none-any.whl", hash = "sha256:f9c26ba06f9747749ca1e5c94d6a85cb84254577553c8785576fd38fa64dc0f7", size = 444229 },
]
[[package]]