LightRAG/examples/lightrag_azure_openai_demo.py

127 lines
3.4 KiB
Python
Raw Normal View History

2024-10-18 15:32:58 +08:00
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
import numpy as np
from dotenv import load_dotenv
import logging
from openai import AzureOpenAI
2025-03-03 18:33:42 +08:00
from lightrag.kg.shared_storage import initialize_pipeline_status
2024-10-18 15:32:58 +08:00
# Show INFO-level (and higher) log records on the root logger.
logging.basicConfig(level=logging.INFO)

# Load settings via python-dotenv before the environment is read below
# (presumably from a local .env file -- confirm against your setup).
load_dotenv()

# Azure OpenAI settings; each one is None when its variable is not set.
AZURE_OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION")
AZURE_OPENAI_DEPLOYMENT = os.environ.get("AZURE_OPENAI_DEPLOYMENT")
AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")

AZURE_EMBEDDING_DEPLOYMENT = os.environ.get("AZURE_EMBEDDING_DEPLOYMENT")
AZURE_EMBEDDING_API_VERSION = os.environ.get("AZURE_EMBEDDING_API_VERSION")

# Directory handed to LightRAG as its working_dir.
WORKING_DIR = "./dickens"

# Every run starts from an empty working directory: anything left over from
# a previous run is deleted, then the directory is created again.
if os.path.exists(WORKING_DIR):
    import shutil

    shutil.rmtree(WORKING_DIR)
os.mkdir(WORKING_DIR)
async def llm_model_func(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> str:
    """Request a chat completion from the configured Azure OpenAI deployment.

    Args:
        prompt: Text sent as the final ``user`` message.
        system_prompt: Optional text sent first as a ``system`` message.
        history_messages: Optional list of earlier chat messages, inserted
            between the system message and the user prompt. ``None`` (the
            default) and an empty list both mean "no history".
        keyword_extraction: Accepted so callers may pass it; not used here.
        **kwargs: Only ``temperature`` (default 0), ``top_p`` (default 1) and
            ``n`` (default 1) are read; any other keys are ignored.

    Returns:
        The message content of the first choice in the response.
    """
    client = AzureOpenAI(
        api_key=AZURE_OPENAI_API_KEY,
        api_version=AZURE_OPENAI_API_VERSION,
        azure_endpoint=AZURE_OPENAI_ENDPOINT,
    )

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    if history_messages:
        messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # AzureOpenAI is a synchronous client: calling it directly inside this
    # coroutine would block the event loop for the whole HTTP round trip.
    # Running it in a worker thread lets other coroutines make progress.
    chat_completion = await asyncio.to_thread(
        client.chat.completions.create,
        model=AZURE_OPENAI_DEPLOYMENT,  # model = "deployment_name".
        messages=messages,
        temperature=kwargs.get("temperature", 0),
        top_p=kwargs.get("top_p", 1),
        n=kwargs.get("n", 1),
    )
    return chat_completion.choices[0].message.content
2024-10-18 15:32:58 +08:00
async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embed ``texts`` with the configured Azure OpenAI embedding deployment.

    Args:
        texts: Strings to embed; sent to the service in a single request.

    Returns:
        A NumPy array built from the embedding of each item in the response,
        one row per returned item.
    """
    client = AzureOpenAI(
        api_key=AZURE_OPENAI_API_KEY,
        api_version=AZURE_EMBEDDING_API_VERSION,
        azure_endpoint=AZURE_OPENAI_ENDPOINT,
    )

    # The client is synchronous; run the request in a worker thread so the
    # event loop is not blocked while waiting on the network.
    response = await asyncio.to_thread(
        client.embeddings.create,
        model=AZURE_EMBEDDING_DEPLOYMENT,
        input=texts,
    )
    return np.array([item.embedding for item in response.data])
2024-10-18 15:32:58 +08:00
async def test_funcs():
    """Call both Azure helper functions once and print what they return.

    Prints the chat reply, then the shape of the embedding array and its
    second dimension.
    """
    reply = await llm_model_func("How are you?")
    print("Resposta do llm_model_func: ", reply)

    vectors = await embedding_func(["How are you?"])
    shape = vectors.shape
    print("Resultado do embedding_func: ", shape)
    print("Dimensão da embedding: ", shape[1])


# Executed at module level, so the check runs (and contacts the service)
# whenever this file is run or imported.
asyncio.run(test_funcs())
# Vector length declared to LightRAG for embedding_func's output.
# NOTE(review): presumably must equal the dimension printed by test_funcs
# for the chosen embedding deployment -- confirm when switching models.
embedding_dimension = 3072


async def initialize_rag():
    """Create a LightRAG instance wired to the Azure helpers and prepare it.

    Returns:
        The LightRAG object, after its storages and the pipeline status have
        been initialized.
    """
    embedder = EmbeddingFunc(
        embedding_dim=embedding_dimension,
        max_token_size=8192,
        func=embedding_func,
    )
    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        embedding_func=embedder,
    )

    # Both initialization steps are awaited before the instance is returned.
    await rag.initialize_storages()
    await initialize_pipeline_status()

    return rag
def main():
    """Index two book files and print answers for each query mode.

    Builds the RAG instance, inserts the contents of ``./book_1.txt`` and
    ``./book_2.txt``, then runs the same question in the naive, local,
    global and hybrid modes, printing each result under its own heading.
    """
    rag = asyncio.run(initialize_rag())

    # Read both books under context managers so the file handles are closed
    # even if reading or decoding fails.
    with open("./book_1.txt", encoding="utf-8") as book1, open(
        "./book_2.txt", encoding="utf-8"
    ) as book2:
        documents = [book1.read(), book2.read()]
    rag.insert(documents)

    query_text = "What are the main themes?"

    for position, mode in enumerate(("naive", "local", "global", "hybrid")):
        # Every heading except the first is preceded by a blank line.
        separator = "\n" if position else ""
        print(f"{separator}Result ({mode.capitalize()}):")
        print(rag.query(query_text, param=QueryParam(mode=mode)))


# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()