diff --git a/examples/lightrag_azure_openai_demo.py b/examples/lightrag_azure_openai_demo.py index e29a6a9d..98d7c0e0 100644 --- a/examples/lightrag_azure_openai_demo.py +++ b/examples/lightrag_azure_openai_demo.py @@ -6,6 +6,7 @@ import numpy as np from dotenv import load_dotenv import aiohttp import logging +from openai import AzureOpenAI logging.basicConfig(level=logging.INFO) @@ -32,11 +33,12 @@ os.mkdir(WORKING_DIR) async def llm_model_func( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: - headers = { - "Content-Type": "application/json", - "api-key": AZURE_OPENAI_API_KEY, - } - endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version={AZURE_OPENAI_API_VERSION}" + + client = AzureOpenAI( + api_key=AZURE_OPENAI_API_KEY, + api_version=AZURE_OPENAI_API_VERSION, + azure_endpoint=AZURE_OPENAI_ENDPOINT + ) messages = [] if system_prompt: @@ -45,41 +47,30 @@ async def llm_model_func( messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - payload = { - "messages": messages, - "temperature": kwargs.get("temperature", 0), - "top_p": kwargs.get("top_p", 1), - "n": kwargs.get("n", 1), - } - - async with aiohttp.ClientSession() as session: - async with session.post(endpoint, headers=headers, json=payload) as response: - if response.status != 200: - raise ValueError( - f"Request failed with status {response.status}: {await response.text()}" - ) - result = await response.json() - return result["choices"][0]["message"]["content"] + chat_completion = client.chat.completions.create( + model=AZURE_OPENAI_DEPLOYMENT, # model = "deployment_name". + messages=messages, + temperature=kwargs.get("temperature", 0), + top_p=kwargs.get("top_p", 1), + n=kwargs.get("n", 1), + ) + return chat_completion.choices[0].message.content async def embedding_func(texts: list[str]) -> np.ndarray: - headers = { - "Content-Type": "application/json", - "api-key": AZURE_OPENAI_API_KEY, - } - endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_EMBEDDING_DEPLOYMENT}/embeddings?api-version={AZURE_EMBEDDING_API_VERSION}" - payload = {"input": texts} - - async with aiohttp.ClientSession() as session: - async with session.post(endpoint, headers=headers, json=payload) as response: - if response.status != 200: - raise ValueError( - f"Request failed with status {response.status}: {await response.text()}" - ) - result = await response.json() - embeddings = [item["embedding"] for item in result["data"]] - return np.array(embeddings) + client = AzureOpenAI( + api_key=AZURE_OPENAI_API_KEY, + api_version=AZURE_EMBEDDING_API_VERSION, + azure_endpoint=AZURE_OPENAI_ENDPOINT + ) + embedding = client.embeddings.create( + model=AZURE_EMBEDDING_DEPLOYMENT, + input=texts + ) + + embeddings = [item.embedding for item in embedding.data] + return np.array(embeddings) async def test_funcs():