mirror of
https://github.com/OpenSPG/KAG.git
synced 2025-06-27 03:20:08 +00:00
feat(builder): add Azure Open AI Compatibility (#269)
* feat(llm): add Azure OpenAI client and vectorization support * chore: add .DS_Store to .gitignore * refactor(llm): add description for api_version and default value * refactor(vectorize_model): added description for api_version and default values for some params * refactor(openai_model): enhance docstring for Azure AD token and deployment parameters
This commit is contained in:
parent
671a9a016c
commit
6494fd20c0
1
.gitignore
vendored
1
.gitignore
vendored
@ -15,3 +15,4 @@
|
||||
.idea/
|
||||
.venv/
|
||||
__pycache__/
|
||||
.DS_Store
|
@ -12,19 +12,22 @@
|
||||
|
||||
|
||||
import json
|
||||
from openai import OpenAI
|
||||
from openai import OpenAI, AzureOpenAI
|
||||
import logging
|
||||
|
||||
from kag.interface import LLMClient
|
||||
from tenacity import retry, stop_after_attempt
|
||||
from typing import Callable
|
||||
|
||||
logging.getLogger("openai").setLevel(logging.ERROR)
|
||||
logging.getLogger("httpx").setLevel(logging.ERROR)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AzureADTokenProvider = Callable[[], str]
|
||||
|
||||
@LLMClient.register("maas")
|
||||
@LLMClient.register("openai")
|
||||
|
||||
class OpenAIClient(LLMClient):
|
||||
"""
|
||||
A client class for interacting with the OpenAI API.
|
||||
@ -134,3 +137,119 @@ class OpenAIClient(LLMClient):
|
||||
except:
|
||||
return rsp
|
||||
return json_result
|
||||
@LLMClient.register("azure_openai")
class AzureOpenAIClient(LLMClient):
    """A client class for interacting with the Azure OpenAI chat completions API."""

    def __init__(
        self,
        api_key: str,
        base_url: str,
        model: str,
        stream: bool = False,
        api_version: str = "2024-12-01-preview",
        temperature: float = 0.7,
        azure_deployment: str = None,
        timeout: float = None,
        azure_ad_token: str = None,
        azure_ad_token_provider: AzureADTokenProvider = None,
    ):
        """
        Initializes the AzureOpenAIClient instance.

        Args:
            api_key (str): The API key for accessing the Azure OpenAI API.
            base_url (str): The base URL for the Azure OpenAI API.
            model (str): The default model to use for requests.
            stream (bool, optional): Whether to stream the response. Defaults to False.
            api_version (str): The API version for the Azure OpenAI API (eg. "2024-12-01-preview, 2024-10-01-preview,2024-05-01-preview").
            temperature (float, optional): The temperature parameter for the model. Defaults to 0.7.
            azure_deployment (str, optional): A model deployment; if given, sets the base client URL to include
                `/deployments/{azure_deployment}`. Note: this means you won't be able to use non-deployment
                endpoints. Not supported with Assistants APIs.
            timeout (float, optional): The timeout duration for the service request. Defaults to None, means no timeout.
            azure_ad_token (str, optional): Your Azure Active Directory token,
                https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
            azure_ad_token_provider (AzureADTokenProvider, optional): A function that returns an Azure Active
                Directory token, will be invoked on every request.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.azure_deployment = azure_deployment
        self.model = model
        self.stream = stream
        self.temperature = temperature
        self.timeout = timeout
        self.api_version = api_version
        self.azure_ad_token = azure_ad_token
        self.azure_ad_token_provider = azure_ad_token_provider
        # NOTE: AzureOpenAI's constructor does not accept a `model` keyword
        # (passing one raises TypeError); the model is supplied per request
        # in `chat.completions.create` instead.
        self.client = AzureOpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            azure_deployment=self.azure_deployment,
            api_version=self.api_version,
            azure_ad_token=self.azure_ad_token,
            azure_ad_token_provider=self.azure_ad_token_provider,
        )
        self.check()

    def __call__(self, prompt: str, image_url: str = None):
        """
        Executes a model request when the object is called and returns the result.

        Args:
            prompt (str): The prompt provided to the model.
            image_url (str, optional): If given, the prompt is sent together with the
                image as multimodal user content.

        Returns:
            str: The response content generated by the model.
        """
        # Build the message list; only the user content differs between the
        # text-only and the multimodal case, so the request itself is shared.
        if image_url:
            user_content = [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": image_url}},
            ]
        else:
            user_content = prompt
        message = [
            {"role": "system", "content": "you are a helpful assistant"},
            {"role": "user", "content": user_content},
        ]
        response = self.client.chat.completions.create(
            model=self.model,
            messages=message,
            stream=self.stream,
            temperature=self.temperature,
            timeout=self.timeout,
        )
        rsp = response.choices[0].message.content
        return rsp

    @retry(stop=stop_after_attempt(3))
    def call_with_json_parse(self, prompt):
        """
        Calls the model and attempts to parse the response into JSON format.

        Args:
            prompt (str): The prompt provided to the model.

        Returns:
            Union[dict, str]: If the response is valid JSON, returns the parsed object;
                otherwise, returns the original response string.
        """
        rsp = self(prompt)
        # Strip an optional ```json ... ``` fence around the payload.
        _end = rsp.rfind("```")
        _start = rsp.find("```json")
        if _end != -1 and _start != -1:
            json_str = rsp[_start + len("```json") : _end].strip()
        else:
            json_str = rsp
        try:
            json_result = json.loads(json_str)
        except json.JSONDecodeError:
            # Not valid JSON: fall back to returning the raw model response.
            return rsp
        return json_result
|
@ -10,9 +10,9 @@
|
||||
# or implied.
|
||||
|
||||
from typing import Union, Iterable
|
||||
from openai import OpenAI
|
||||
from openai import OpenAI, AzureOpenAI
|
||||
from kag.interface import VectorizeModelABC, EmbeddingVector
|
||||
|
||||
from typing import Callable
|
||||
|
||||
@VectorizeModelABC.register("openai")
|
||||
class OpenAIVectorizeModel(VectorizeModelABC):
|
||||
@ -65,3 +65,71 @@ class OpenAIVectorizeModel(VectorizeModelABC):
|
||||
else:
|
||||
assert len(results) == len(texts)
|
||||
return results
|
||||
|
||||
@VectorizeModelABC.register("azure_openai")
class AzureOpenAIVectorizeModel(VectorizeModelABC):
    """A class that extends the VectorizeModelABC base class.

    It invokes Azure OpenAI or Azure OpenAI-compatible embedding services to
    convert texts into embedding vectors.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str = "text-embedding-ada-002",
        api_version: str = "2024-12-01-preview",
        vector_dimensions: int = None,
        timeout: float = None,
        azure_deployment: str = None,
        azure_ad_token: str = None,
        azure_ad_token_provider: Callable = None,
    ):
        """
        Initializes the AzureOpenAIVectorizeModel instance.

        Args:
            base_url (str): The base URL for the Azure OpenAI service.
            api_key (str): The API key for accessing the Azure OpenAI service.
            model (str, optional): The model to use for embedding. Defaults to "text-embedding-ada-002".
            api_version (str): The API version for the Azure OpenAI API (eg. "2024-12-01-preview, 2024-10-01-preview,2024-05-01-preview").
            vector_dimensions (int, optional): The number of dimensions for the embedding vectors. Defaults to None.
            timeout (float, optional): The timeout duration for embedding requests. Defaults to None, means no timeout.
            azure_deployment (str, optional): A model deployment; if given, sets the base client URL to include
                `/deployments/{azure_deployment}`. Note: this means you won't be able to use non-deployment
                endpoints. Not supported with Assistants APIs.
            azure_ad_token (str, optional): Your Azure Active Directory token,
                https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
            azure_ad_token_provider (Callable, optional): A function that returns an Azure Active Directory
                token, will be invoked on every request.
        """
        super().__init__(vector_dimensions)
        self.model = model
        self.timeout = timeout
        # NOTE: AzureOpenAI's constructor does not accept a `model` keyword
        # (passing one raises TypeError); the model is supplied per request
        # in `embeddings.create` instead.
        self.client = AzureOpenAI(
            api_key=api_key,
            base_url=base_url,
            azure_deployment=azure_deployment,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            azure_ad_token_provider=azure_ad_token_provider,
        )

    def vectorize(
        self, texts: Union[str, Iterable[str]]
    ) -> Union[EmbeddingVector, Iterable[EmbeddingVector]]:
        """
        Vectorizes a text string into an embedding vector or multiple text strings into multiple embedding vectors.

        Args:
            texts (Union[str, Iterable[str]]): The text or texts to vectorize.

        Returns:
            Union[EmbeddingVector, Iterable[EmbeddingVector]]: The embedding vector(s) of the text(s).
        """
        results = self.client.embeddings.create(
            input=texts, model=self.model, timeout=self.timeout
        )
        results = [item.embedding for item in results.data]
        if isinstance(texts, str):
            # A single input string yields exactly one embedding.
            assert len(results) == 1
            return results[0]
        else:
            assert len(results) == len(texts)
            return results
|
Loading…
x
Reference in New Issue
Block a user