From 2a3d92b51502794c4e10a76ddf31e848ac6827a9 Mon Sep 17 00:00:00 2001
From: Magic_yuan <72277650+magicyuan876@users.noreply.github.com>
Date: Thu, 21 Nov 2024 10:37:09 +0800
Subject: [PATCH] Implement with the AzureOpenAI client, supporting RPM/TPM
 limits. Fix the issue where a 429 response previously raised an exception
 immediately.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 examples/lightrag_azure_openai_demo.py | 63 +++++++++++---------------
 1 file changed, 27 insertions(+), 36 deletions(-)

diff --git a/examples/lightrag_azure_openai_demo.py b/examples/lightrag_azure_openai_demo.py
index e29a6a9d..98d7c0e0 100644
--- a/examples/lightrag_azure_openai_demo.py
+++ b/examples/lightrag_azure_openai_demo.py
@@ -6,6 +6,7 @@ import numpy as np
 from dotenv import load_dotenv
 import aiohttp
 import logging
+from openai import AzureOpenAI
 
 logging.basicConfig(level=logging.INFO)
 
@@ -32,11 +33,12 @@ os.mkdir(WORKING_DIR)
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
-    headers = {
-        "Content-Type": "application/json",
-        "api-key": AZURE_OPENAI_API_KEY,
-    }
-    endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version={AZURE_OPENAI_API_VERSION}"
+
+    client = AzureOpenAI(
+        api_key=AZURE_OPENAI_API_KEY,
+        api_version=AZURE_OPENAI_API_VERSION,
+        azure_endpoint=AZURE_OPENAI_ENDPOINT
+    )
 
     messages = []
     if system_prompt:
@@ -45,41 +47,30 @@ async def llm_model_func(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
 
-    payload = {
-        "messages": messages,
-        "temperature": kwargs.get("temperature", 0),
-        "top_p": kwargs.get("top_p", 1),
-        "n": kwargs.get("n", 1),
-    }
-
-    async with aiohttp.ClientSession() as session:
-        async with session.post(endpoint, headers=headers, json=payload) as response:
-            if response.status != 200:
-                raise ValueError(
-                    f"Request failed with status {response.status}: {await response.text()}"
-                )
-            result = await response.json()
-            return result["choices"][0]["message"]["content"]
+    chat_completion = client.chat.completions.create(
+        model=AZURE_OPENAI_DEPLOYMENT,  # model = "deployment_name"
+        messages=messages,
+        temperature=kwargs.get("temperature", 0),
+        top_p=kwargs.get("top_p", 1),
+        n=kwargs.get("n", 1),
+    )
+    return chat_completion.choices[0].message.content
 
 
 async def embedding_func(texts: list[str]) -> np.ndarray:
-    headers = {
-        "Content-Type": "application/json",
-        "api-key": AZURE_OPENAI_API_KEY,
-    }
-    endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_EMBEDDING_DEPLOYMENT}/embeddings?api-version={AZURE_EMBEDDING_API_VERSION}"
-    payload = {"input": texts}
-
-    async with aiohttp.ClientSession() as session:
-        async with session.post(endpoint, headers=headers, json=payload) as response:
-            if response.status != 200:
-                raise ValueError(
-                    f"Request failed with status {response.status}: {await response.text()}"
-                )
-            result = await response.json()
-            embeddings = [item["embedding"] for item in result["data"]]
-            return np.array(embeddings)
+    client = AzureOpenAI(
+        api_key=AZURE_OPENAI_API_KEY,
+        api_version=AZURE_EMBEDDING_API_VERSION,
+        azure_endpoint=AZURE_OPENAI_ENDPOINT
+    )
+    embedding = client.embeddings.create(
+        model=AZURE_EMBEDDING_DEPLOYMENT,
+        input=texts
+    )
+
+    embeddings = [item.embedding for item in embedding.data]
+    return np.array(embeddings)
 
 
 async def test_funcs():
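Note on the retry behavior: the migration fixes the 429 problem because the openai v1 client retries rate-limited requests itself (exponential backoff on 429 and transient server errors) rather than surfacing the first 429 the way the raw aiohttp calls did, so Azure RPM/TPM throttling no longer aborts the run. A minimal sketch of the same call path follows; it uses the non-blocking AsyncAzureOpenAI variant of the client so the event loop is not stalled inside the demo's async functions, and the max_retries value is an illustrative assumption, not something the patch sets:

    import os

    from openai import AsyncAzureOpenAI

    # Async variant of the client used in the patch; retries on 429 are
    # built in. max_retries=5 is an assumed value (the library default is 2).
    client = AsyncAzureOpenAI(
        api_key=os.environ["AZURE_OPENAI_API_KEY"],
        api_version=os.environ["AZURE_OPENAI_API_VERSION"],
        azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
        max_retries=5,
    )

    async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.extend(history_messages)
        messages.append({"role": "user", "content": prompt})
        # await keeps the event loop free while the client waits out 429 backoff
        chat_completion = await client.chat.completions.create(
            model=os.environ["AZURE_OPENAI_DEPLOYMENT"],  # deployment name, as in the patch
            messages=messages,
            temperature=kwargs.get("temperature", 0),
            top_p=kwargs.get("top_p", 1),
            n=kwargs.get("n", 1),
        )
        return chat_completion.choices[0].message.content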