mirror of
https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00
Allow usage of different openai compatible clients in embedder and encoder (#279)
* allow usage of different openai compatible clients in embedder and encoder * azure openai * cross encoder example --------- Co-authored-by: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
This commit is contained in:
parent
55e308fb9f
commit
5cad6c8504
@ -205,6 +205,7 @@ from openai import AsyncAzureOpenAI
|
||||
from graphiti_core import Graphiti
|
||||
from graphiti_core.llm_client import OpenAIClient
|
||||
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
|
||||
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
||||
|
||||
# Azure OpenAI configuration
|
||||
api_key = "<your-api-key>"
|
||||
@ -231,6 +232,10 @@ graphiti = Graphiti(
|
||||
embedding_model="text-embedding-3-small" # Use your Azure deployed embedding model name
|
||||
),
|
||||
client=azure_openai_client
|
||||
),
|
||||
# Optional: Configure the OpenAI cross encoder with Azure OpenAI
|
||||
cross_encoder=OpenAIRerankerClient(
|
||||
client=azure_openai_client
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -0,0 +1,21 @@
|
||||
"""
|
||||
Copyright 2025, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from .bge_reranker_client import BGERerankerClient
|
||||
from .client import CrossEncoderClient
|
||||
from .openai_reranker_client import OpenAIRerankerClient
|
||||
|
||||
__all__ = ['CrossEncoderClient', 'BGERerankerClient', 'OpenAIRerankerClient']
|
@ -18,7 +18,7 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
from openai import AsyncOpenAI
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..helpers import semaphore_gather
|
||||
@ -36,21 +36,29 @@ class BooleanClassifier(BaseModel):
|
||||
|
||||
|
||||
class OpenAIRerankerClient(CrossEncoderClient):
|
||||
def __init__(self, config: LLMConfig | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: LLMConfig | None = None,
|
||||
client: AsyncOpenAI | AsyncAzureOpenAI | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the OpenAIClient with the provided configuration, cache setting, and client.
|
||||
Initialize the OpenAIRerankerClient with the provided configuration and client.
|
||||
|
||||
This reranker uses the OpenAI API to run a simple boolean classifier prompt concurrently
|
||||
for each passage. Log-probabilities are used to rank the passages.
|
||||
|
||||
Args:
|
||||
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
|
||||
cache (bool): Whether to use caching for responses. Defaults to False.
|
||||
client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
||||
|
||||
client (AsyncOpenAI | AsyncAzureOpenAI | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
||||
"""
|
||||
if config is None:
|
||||
config = LLMConfig()
|
||||
|
||||
self.config = config
|
||||
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
if client is None:
|
||||
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
else:
|
||||
self.client = client
|
||||
|
||||
async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
|
||||
openai_messages_list: Any = [
|
||||
@ -62,7 +70,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
||||
Message(
|
||||
role='user',
|
||||
content=f"""
|
||||
Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
|
||||
Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
|
||||
<PASSAGE>
|
||||
{passage}
|
||||
</PASSAGE>
|
||||
|
Loading…
x
Reference in New Issue
Block a user