mirror of
https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00
feat: add Azure client wrappers for embedding and LLM, integrate into server (#581)
* create wrappers for azure clients * rremove unused crossencoder client * format * chore: update graphiti-core to 0.12.0rc5 and pydantic to 2.11.5 * Update graphiti_core/llm_client/azure_openai_client.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --------- Co-authored-by: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
5287810d2d
commit
3d7e1a4b79
@ -106,7 +106,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
||||
if len(top_logprobs) == 0:
|
||||
continue
|
||||
norm_logprobs = np.exp(top_logprobs[0].logprob)
|
||||
if top_logprobs[0].token.strip().split(" ")[0].lower() == "true":
|
||||
if top_logprobs[0].token.strip().split(' ')[0].lower() == 'true':
|
||||
scores.append(norm_logprobs)
|
||||
else:
|
||||
scores.append(1 - norm_logprobs)
|
||||
|
64
graphiti_core/embedder/azure_openai.py
Normal file
64
graphiti_core/embedder/azure_openai.py
Normal file
@ -0,0 +1,64 @@
|
||||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncAzureOpenAI
|
||||
|
||||
from .client import EmbedderClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureOpenAIEmbedderClient(EmbedderClient):
|
||||
"""Wrapper class for AsyncAzureOpenAI that implements the EmbedderClient interface."""
|
||||
|
||||
def __init__(self, azure_client: AsyncAzureOpenAI, model: str = 'text-embedding-3-small'):
|
||||
self.azure_client = azure_client
|
||||
self.model = model
|
||||
|
||||
async def create(self, input_data: str | list[str] | Any) -> list[float]:
|
||||
"""Create embeddings using Azure OpenAI client."""
|
||||
try:
|
||||
# Handle different input types
|
||||
if isinstance(input_data, str):
|
||||
text_input = [input_data]
|
||||
elif isinstance(input_data, list) and all(isinstance(item, str) for item in input_data):
|
||||
text_input = input_data
|
||||
else:
|
||||
# Convert to string list for other types
|
||||
text_input = [str(input_data)]
|
||||
|
||||
response = await self.azure_client.embeddings.create(model=self.model, input=text_input)
|
||||
|
||||
# Return the first embedding as a list of floats
|
||||
return response.data[0].embedding
|
||||
except Exception as e:
|
||||
logger.error(f'Error in Azure OpenAI embedding: {e}')
|
||||
raise
|
||||
|
||||
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
|
||||
"""Create batch embeddings using Azure OpenAI client."""
|
||||
try:
|
||||
response = await self.azure_client.embeddings.create(
|
||||
model=self.model, input=input_data_list
|
||||
)
|
||||
|
||||
return [embedding.embedding for embedding in response.data]
|
||||
except Exception as e:
|
||||
logger.error(f'Error in Azure OpenAI batch embedding: {e}')
|
||||
raise
|
73
graphiti_core/llm_client/azure_openai_client.py
Normal file
73
graphiti_core/llm_client/azure_openai_client.py
Normal file
@ -0,0 +1,73 @@
|
||||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncAzureOpenAI
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..prompts.models import Message
|
||||
from .client import LLMClient
|
||||
from .config import LLMConfig, ModelSize
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureOpenAILLMClient(LLMClient):
|
||||
"""Wrapper class for AsyncAzureOpenAI that implements the LLMClient interface."""
|
||||
|
||||
def __init__(self, azure_client: AsyncAzureOpenAI, config: LLMConfig | None = None):
|
||||
super().__init__(config, cache=False)
|
||||
self.azure_client = azure_client
|
||||
|
||||
async def _generate_response(
|
||||
self,
|
||||
messages: list[Message],
|
||||
response_model: type[BaseModel] | None = None,
|
||||
max_tokens: int = 1024,
|
||||
model_size: ModelSize = ModelSize.medium,
|
||||
) -> dict[str, Any]:
|
||||
"""Generate response using Azure OpenAI client."""
|
||||
# Convert messages to OpenAI format
|
||||
openai_messages: list[ChatCompletionMessageParam] = []
|
||||
for message in messages:
|
||||
message.content = self._clean_input(message.content)
|
||||
if message.role == 'user':
|
||||
openai_messages.append({'role': 'user', 'content': message.content})
|
||||
elif message.role == 'system':
|
||||
openai_messages.append({'role': 'system', 'content': message.content})
|
||||
|
||||
# Ensure model is a string
|
||||
model_name = self.model if self.model else 'gpt-4o-mini'
|
||||
|
||||
try:
|
||||
response = await self.azure_client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=openai_messages,
|
||||
temperature=float(self.temperature) if self.temperature is not None else 0.7,
|
||||
max_tokens=max_tokens,
|
||||
response_format={'type': 'json_object'},
|
||||
)
|
||||
result = response.choices[0].message.content or '{}'
|
||||
|
||||
# Parse JSON response
|
||||
return json.loads(result)
|
||||
except Exception as e:
|
||||
logger.error(f'Error in Azure OpenAI LLM response: {e}')
|
||||
raise
|
@ -19,12 +19,12 @@ from openai import AsyncAzureOpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphiti_core import Graphiti
|
||||
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
||||
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
|
||||
from graphiti_core.embedder.client import EmbedderClient
|
||||
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
|
||||
from graphiti_core.llm_client import LLMClient
|
||||
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
|
||||
from graphiti_core.llm_client.config import LLMConfig
|
||||
from graphiti_core.llm_client.openai_client import OpenAIClient
|
||||
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
||||
@ -37,6 +37,7 @@ from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
DEFAULT_LLM_MODEL = 'gpt-4.1-mini'
|
||||
SMALL_LLM_MODEL = 'gpt-4.1-nano'
|
||||
DEFAULT_EMBEDDER_MODEL = 'text-embedding-3-small'
|
||||
@ -282,11 +283,11 @@ class GraphitiLLMConfig(BaseModel):
|
||||
|
||||
return config
|
||||
|
||||
def create_client(self) -> LLMClient | None:
|
||||
def create_client(self) -> LLMClient:
|
||||
"""Create an LLM client based on this configuration.
|
||||
|
||||
Returns:
|
||||
LLMClient instance if able, None otherwise
|
||||
LLMClient instance
|
||||
"""
|
||||
|
||||
if self.azure_openai_endpoint is not None:
|
||||
@ -294,26 +295,41 @@ class GraphitiLLMConfig(BaseModel):
|
||||
if self.azure_openai_use_managed_identity:
|
||||
# Use managed identity for authentication
|
||||
token_provider = create_azure_credential_token_provider()
|
||||
return AsyncAzureOpenAI(
|
||||
azure_endpoint=self.azure_openai_endpoint,
|
||||
azure_deployment=self.azure_openai_deployment_name,
|
||||
api_version=self.azure_openai_api_version,
|
||||
azure_ad_token_provider=token_provider,
|
||||
return AzureOpenAILLMClient(
|
||||
azure_client=AsyncAzureOpenAI(
|
||||
azure_endpoint=self.azure_openai_endpoint,
|
||||
azure_deployment=self.azure_openai_deployment_name,
|
||||
api_version=self.azure_openai_api_version,
|
||||
azure_ad_token_provider=token_provider,
|
||||
),
|
||||
config=LLMConfig(
|
||||
api_key=self.api_key,
|
||||
model=self.model,
|
||||
small_model=self.small_model,
|
||||
temperature=self.temperature,
|
||||
),
|
||||
)
|
||||
elif self.api_key:
|
||||
# Use API key for authentication
|
||||
return AsyncAzureOpenAI(
|
||||
azure_endpoint=self.azure_openai_endpoint,
|
||||
azure_deployment=self.azure_openai_deployment_name,
|
||||
api_version=self.azure_openai_api_version,
|
||||
api_key=self.api_key,
|
||||
return AzureOpenAILLMClient(
|
||||
azure_client=AsyncAzureOpenAI(
|
||||
azure_endpoint=self.azure_openai_endpoint,
|
||||
azure_deployment=self.azure_openai_deployment_name,
|
||||
api_version=self.azure_openai_api_version,
|
||||
api_key=self.api_key,
|
||||
),
|
||||
config=LLMConfig(
|
||||
api_key=self.api_key,
|
||||
model=self.model,
|
||||
small_model=self.small_model,
|
||||
temperature=self.temperature,
|
||||
),
|
||||
)
|
||||
else:
|
||||
logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API')
|
||||
return None
|
||||
raise ValueError('OPENAI_API_KEY must be set when using Azure OpenAI API')
|
||||
|
||||
if not self.api_key:
|
||||
return None
|
||||
raise ValueError('OPENAI_API_KEY must be set when using OpenAI API')
|
||||
|
||||
llm_client_config = LLMConfig(
|
||||
api_key=self.api_key, model=self.model, small_model=self.small_model
|
||||
@ -324,17 +340,6 @@ class GraphitiLLMConfig(BaseModel):
|
||||
|
||||
return OpenAIClient(config=llm_client_config)
|
||||
|
||||
def create_cross_encoder_client(self) -> CrossEncoderClient | None:
|
||||
"""Create a cross-encoder client based on this configuration."""
|
||||
if self.azure_openai_endpoint is not None:
|
||||
client = self.create_client()
|
||||
return OpenAIRerankerClient(client=client)
|
||||
else:
|
||||
llm_client_config = LLMConfig(
|
||||
api_key=self.api_key, model=self.model, small_model=self.small_model
|
||||
)
|
||||
return OpenAIRerankerClient(config=llm_client_config)
|
||||
|
||||
|
||||
class GraphitiEmbedderConfig(BaseModel):
|
||||
"""Configuration for the embedder client.
|
||||
@ -404,19 +409,25 @@ class GraphitiEmbedderConfig(BaseModel):
|
||||
if self.azure_openai_use_managed_identity:
|
||||
# Use managed identity for authentication
|
||||
token_provider = create_azure_credential_token_provider()
|
||||
return AsyncAzureOpenAI(
|
||||
azure_endpoint=self.azure_openai_endpoint,
|
||||
azure_deployment=self.azure_openai_deployment_name,
|
||||
api_version=self.azure_openai_api_version,
|
||||
azure_ad_token_provider=token_provider,
|
||||
return AzureOpenAIEmbedderClient(
|
||||
azure_client=AsyncAzureOpenAI(
|
||||
azure_endpoint=self.azure_openai_endpoint,
|
||||
azure_deployment=self.azure_openai_deployment_name,
|
||||
api_version=self.azure_openai_api_version,
|
||||
azure_ad_token_provider=token_provider,
|
||||
),
|
||||
model=self.model,
|
||||
)
|
||||
elif self.api_key:
|
||||
# Use API key for authentication
|
||||
return AsyncAzureOpenAI(
|
||||
azure_endpoint=self.azure_openai_endpoint,
|
||||
azure_deployment=self.azure_openai_deployment_name,
|
||||
api_version=self.azure_openai_api_version,
|
||||
api_key=self.api_key,
|
||||
return AzureOpenAIEmbedderClient(
|
||||
azure_client=AsyncAzureOpenAI(
|
||||
azure_endpoint=self.azure_openai_endpoint,
|
||||
azure_deployment=self.azure_openai_deployment_name,
|
||||
api_version=self.azure_openai_api_version,
|
||||
api_key=self.api_key,
|
||||
),
|
||||
model=self.model,
|
||||
)
|
||||
else:
|
||||
logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API')
|
||||
@ -570,7 +581,6 @@ async def initialize_graphiti():
|
||||
raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')
|
||||
|
||||
embedder_client = config.embedder.create_client()
|
||||
cross_encoder_client = config.llm.create_cross_encoder_client()
|
||||
|
||||
# Initialize Graphiti client
|
||||
graphiti_client = Graphiti(
|
||||
@ -579,7 +589,6 @@ async def initialize_graphiti():
|
||||
password=config.neo4j.password,
|
||||
llm_client=llm_client,
|
||||
embedder=embedder_client,
|
||||
cross_encoder=cross_encoder_client,
|
||||
)
|
||||
|
||||
# Destroy graph if requested
|
||||
|
@ -7,6 +7,6 @@ requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"mcp>=1.5.0",
|
||||
"openai>=1.68.2",
|
||||
"graphiti-core>=0.8.2",
|
||||
"graphiti-core>=0.11.6",
|
||||
"azure-identity>=1.21.0",
|
||||
]
|
||||
|
28
mcp_server/uv.lock
generated
28
mcp_server/uv.lock
generated
@ -282,8 +282,8 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "graphiti-core"
|
||||
version = "0.11.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
version = "0.12.0rc5"
|
||||
source = { directory = "../" }
|
||||
dependencies = [
|
||||
{ name = "diskcache" },
|
||||
{ name = "neo4j" },
|
||||
@ -293,9 +293,19 @@ dependencies = [
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "tenacity" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/30/94/3f84400e5f02ea8e9dc79784202de4173cbc16f4b3ad1bd4302da888e4d8/graphiti_core-0.11.6.tar.gz", hash = "sha256:31d26621834d7d4b8865059ab749feb18af15937b59c69598a640a5dfabea331", size = 71928 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/2e/c8f22f01585bf173d1c82f6d4615511aebc75aeda764c69aa394446fa93c/graphiti_core-0.11.6-py3-none-any.whl", hash = "sha256:6ec4807a884f5ea88b942d0c8b7bcd2e107c7358ab4f98ef2a2092c229929707", size = 111001 },
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.49.0" },
|
||||
{ name = "diskcache", specifier = ">=5.6.3" },
|
||||
{ name = "google-genai", marker = "extra == 'google-genai'", specifier = ">=1.8.0" },
|
||||
{ name = "groq", marker = "extra == 'groq'", specifier = ">=0.2.0" },
|
||||
{ name = "neo4j", specifier = ">=5.26.0" },
|
||||
{ name = "numpy", specifier = ">=1.0.0" },
|
||||
{ name = "openai", specifier = ">=1.53.0" },
|
||||
{ name = "pydantic", specifier = ">=2.11.5" },
|
||||
{ name = "python-dotenv", specifier = ">=1.0.1" },
|
||||
{ name = "tenacity", specifier = ">=9.0.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -459,7 +469,7 @@ dependencies = [
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "azure-identity", specifier = ">=1.21.0" },
|
||||
{ name = "graphiti-core", specifier = ">=0.8.2" },
|
||||
{ name = "graphiti-core", directory = "../" },
|
||||
{ name = "mcp", specifier = ">=1.5.0" },
|
||||
{ name = "openai", specifier = ">=1.68.2" },
|
||||
]
|
||||
@ -594,7 +604,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.11.4"
|
||||
version = "2.11.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "annotated-types" },
|
||||
@ -602,9 +612,9 @@ dependencies = [
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "typing-inspection" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/77/ab/5250d56ad03884ab5efd07f734203943c8a8ab40d551e208af81d0257bf2/pydantic-2.11.4.tar.gz", hash = "sha256:32738d19d63a226a52eed76645a98ee07c1f410ee41d93b4afbfa85ed8111c2d", size = 786540 }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f0/86/8ce9040065e8f924d642c58e4a344e33163a07f6b57f836d0d734e0ad3fb/pydantic-2.11.5.tar.gz", hash = "sha256:7f853db3d0ce78ce8bbb148c401c2cdd6431b3473c0cdff2755c7690952a7b7a", size = 787102 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e7/12/46b65f3534d099349e38ef6ec98b1a5a81f42536d17e0ba382c28c67ba67/pydantic-2.11.4-py3-none-any.whl", hash = "sha256:d9615eaa9ac5a063471da949c8fc16376a84afb5024688b3ff885693506764eb", size = 443900 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b5/69/831ed22b38ff9b4b64b66569f0e5b7b97cf3638346eb95a2147fdb49ad5f/pydantic-2.11.5-py3-none-any.whl", hash = "sha256:f9c26ba06f9747749ca1e5c94d6a85cb84254577553c8785576fd38fa64dc0f7", size = 444229 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
Loading…
x
Reference in New Issue
Block a user