from typing import Callable

# A zero-argument callable invoked before every request to obtain a fresh
# Azure Active Directory bearer token.
AzureADTokenProvider = Callable[[], str]


@LLMClient.register("azure_openai")
class AzureOpenAIClient(LLMClient):
    """
    A client class for interacting with the Azure OpenAI chat-completions API.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str,
        model: str,
        stream: bool = False,
        api_version: str = "2024-12-01-preview",
        temperature: float = 0.7,
        azure_deployment: str = None,
        timeout: float = None,
        azure_ad_token: str = None,
        azure_ad_token_provider: AzureADTokenProvider = None,
    ):
        """
        Initializes the AzureOpenAIClient instance.

        Args:
            api_key (str): The API key for accessing the Azure OpenAI API.
            base_url (str): The base URL for the Azure OpenAI API.
            model (str): The default model to use for requests.
            stream (bool, optional): Whether to stream the response. Defaults to False.
            api_version (str): The Azure OpenAI API version (e.g. "2024-12-01-preview",
                "2024-10-01-preview", "2024-05-01-preview").
            temperature (float, optional): Sampling temperature for the model. Defaults to 0.7.
            azure_deployment (str, optional): A model deployment name. If given, the client
                base URL includes `/deployments/{azure_deployment}`; non-deployment endpoints
                are then unavailable (not supported with Assistants APIs).
            timeout (float, optional): Request timeout in seconds. None means no timeout.
            azure_ad_token (str, optional): An Azure Active Directory token, see
                https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
            azure_ad_token_provider (optional): A function returning an Azure Active Directory
                token; it will be invoked on every request.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.azure_deployment = azure_deployment
        self.model = model
        self.stream = stream
        self.temperature = temperature
        self.timeout = timeout
        self.api_version = api_version
        self.azure_ad_token = azure_ad_token
        self.azure_ad_token_provider = azure_ad_token_provider
        # BUG FIX: AzureOpenAI() has no `model` keyword parameter — passing it raised
        # TypeError at construction. The model is selected per request in
        # chat.completions.create() instead.
        self.client = AzureOpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            azure_deployment=self.azure_deployment,
            api_version=self.api_version,
            azure_ad_token=self.azure_ad_token,
            azure_ad_token_provider=self.azure_ad_token_provider,
        )
        self.check()

    def __call__(self, prompt: str, image_url: str = None):
        """
        Executes a model request when the object is called and returns the result.

        Args:
            prompt (str): The prompt provided to the model.
            image_url (str, optional): URL of an image to attach to the user message.

        Returns:
            str: The response content generated by the model.
        """
        # Build the user content once; the two original branches duplicated an
        # identical chat.completions.create() call.
        if image_url:
            user_content = [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": image_url}},
            ]
        else:
            user_content = prompt
        message = [
            {"role": "system", "content": "you are a helpful assistant"},
            {"role": "user", "content": user_content},
        ]
        # NOTE(review): with stream=True the SDK returns an event iterator, not a
        # completion object, so `.choices[0].message.content` would fail — confirm
        # streaming is actually exercised by callers.
        response = self.client.chat.completions.create(
            model=self.model,
            messages=message,
            stream=self.stream,
            temperature=self.temperature,
            timeout=self.timeout,
        )
        return response.choices[0].message.content

    @retry(stop=stop_after_attempt(3))
    def call_with_json_parse(self, prompt):
        """
        Calls the model and attempts to parse the response into JSON format.

        Args:
            prompt (str): The prompt provided to the model.

        Returns:
            Union[dict, str]: If the response (or its ```json fenced section) is valid
            JSON, returns the parsed value; otherwise, returns the original response.
        """
        rsp = self(prompt)
        _end = rsp.rfind("```")
        _start = rsp.find("```json")
        if _end != -1 and _start != -1:
            json_str = rsp[_start + len("```json") : _end].strip()
        else:
            json_str = rsp
        try:
            # BUG FIX: was a bare `except:` that also swallowed KeyboardInterrupt,
            # SystemExit and unrelated programming errors.
            return json.loads(json_str)
        except json.JSONDecodeError:
            return rsp
@VectorizeModelABC.register("azure_openai")
class AzureOpenAIVectorizeModel(VectorizeModelABC):
    """
    A class that extends the VectorizeModelABC base class.

    It invokes Azure OpenAI or Azure OpenAI-compatible embedding services to
    convert texts into embedding vectors.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str = "text-embedding-ada-002",
        api_version: str = "2024-12-01-preview",
        vector_dimensions: int = None,
        timeout: float = None,
        azure_deployment: str = None,
        azure_ad_token: str = None,
        azure_ad_token_provider: Callable = None,
    ):
        """
        Initializes the AzureOpenAIVectorizeModel instance.

        Args:
            base_url (str): The base URL for the Azure OpenAI service.
            api_key (str): The API key for accessing the Azure OpenAI service.
            model (str, optional): The model to use for embedding. Defaults to
                "text-embedding-ada-002".
            api_version (str): The Azure OpenAI API version (e.g. "2024-12-01-preview",
                "2024-10-01-preview", "2024-05-01-preview").
            vector_dimensions (int, optional): The number of dimensions for the embedding
                vectors. Defaults to None.
            timeout (float, optional): Request timeout in seconds. None means no timeout.
            azure_deployment (str, optional): A model deployment name. If given, the client
                base URL includes `/deployments/{azure_deployment}`; non-deployment endpoints
                are then unavailable (not supported with Assistants APIs).
            azure_ad_token (str, optional): An Azure Active Directory token, see
                https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
            azure_ad_token_provider (optional): A function returning an Azure Active Directory
                token; it will be invoked on every request.
        """
        super().__init__(vector_dimensions)
        self.model = model
        self.timeout = timeout
        # BUG FIX: AzureOpenAI() has no `model` keyword parameter — passing it raised
        # TypeError at construction. The model is chosen per request in
        # embeddings.create() instead.
        self.client = AzureOpenAI(
            api_key=api_key,
            base_url=base_url,
            azure_deployment=azure_deployment,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            azure_ad_token_provider=azure_ad_token_provider,
        )

    def vectorize(
        self, texts: Union[str, Iterable[str]]
    ) -> Union[EmbeddingVector, Iterable[EmbeddingVector]]:
        """
        Vectorizes a text string into an embedding vector or multiple text strings into
        multiple embedding vectors.

        Args:
            texts (Union[str, Iterable[str]]): The text or texts to vectorize.

        Returns:
            Union[EmbeddingVector, Iterable[EmbeddingVector]]: The embedding vector(s)
            of the text(s). A single vector is returned for a single string input.
        """
        results = self.client.embeddings.create(
            input=texts, model=self.model, timeout=self.timeout
        )
        results = [item.embedding for item in results.data]
        if isinstance(texts, str):
            # The API returns one embedding per input; a str input yields exactly one.
            assert len(results) == 1
            return results[0]
        else:
            assert len(results) == len(texts)
            return results