# LightRAG/examples/test_faiss.py
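"""Example: LightRAG with FAISS vector storage, Azure OpenAI as the LLM, and a
local SentenceTransformer model for embeddings.

Expects AZURE_OPENAI_API_VERSION, AZURE_OPENAI_DEPLOYMENT, AZURE_OPENAI_API_KEY,
and AZURE_OPENAI_ENDPOINT in a .env file (or the environment), and two input
files, book_1.txt and book_2.txt, in the current working directory.
"""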

import os
import logging
import asyncio
import numpy as np
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from openai import AzureOpenAI
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
from lightrag.kg.shared_storage import initialize_pipeline_status

WORKING_DIR = "./dickens"

# Configure Logging
logging.basicConfig(level=logging.INFO)
# Load environment variables from .env file
load_dotenv()
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT")
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
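
# Create the working directory up front if it doesn't exist; LightRAG keeps its
# index files there. (Convenience addition; most LightRAG examples do the same.)
if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)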


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
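    """Call Azure OpenAI chat completions and return the response text."""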
    # Create a client for AzureOpenAI
    client = AzureOpenAI(
        api_key=AZURE_OPENAI_API_KEY,
        api_version=AZURE_OPENAI_API_VERSION,
        azure_endpoint=AZURE_OPENAI_ENDPOINT,
    )

    # Build the messages list for the conversation
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    if history_messages:
        messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # Call the LLM
    chat_completion = client.chat.completions.create(
        model=AZURE_OPENAI_DEPLOYMENT,
        messages=messages,
        temperature=kwargs.get("temperature", 0),
        top_p=kwargs.get("top_p", 1),
        n=kwargs.get("n", 1),
    )
    return chat_completion.choices[0].message.content


async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embed a batch of texts with a local SentenceTransformer model."""
    # Note: the model is reloaded on every call; cache it at module level
    # if embedding throughput matters.
    model = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = model.encode(texts, convert_to_numpy=True)
    return embeddings


async def initialize_rag():
    """Create a LightRAG instance backed by FAISS and initialize its storages."""
    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=384,  # all-MiniLM-L6-v2 produces 384-dimensional vectors
            max_token_size=8192,
            func=embedding_func,
        ),
        vector_storage="FaissVectorDBStorage",
        vector_db_storage_cls_kwargs={
            "cosine_better_than_threshold": 0.2  # Your desired threshold
        },
    )
    await rag.initialize_storages()
    await initialize_pipeline_status()
    return rag


def main():
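    """Index two books, then run the same query in each of LightRAG's modes."""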
    # Initialize RAG instance
    rag = asyncio.run(initialize_rag())

    # Insert the two source documents into LightRAG
    with open("./book_1.txt", encoding="utf-8") as book1, open(
        "./book_2.txt", encoding="utf-8"
    ) as book2:
        rag.insert([book1.read(), book2.read()])

    query_text = "What are the main themes?"

    print("Result (Naive):")
    print(rag.query(query_text, param=QueryParam(mode="naive")))

    print("\nResult (Local):")
    print(rag.query(query_text, param=QueryParam(mode="local")))

    print("\nResult (Global):")
    print(rag.query(query_text, param=QueryParam(mode="global")))

    print("\nResult (Hybrid):")
    print(rag.query(query_text, param=QueryParam(mode="hybrid")))


if __name__ == "__main__":
    main()
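
# Note: rag.insert() and rag.query() are LightRAG's synchronous wrappers; inside
# an already-running event loop, use await rag.ainsert(...) / await rag.aquery(...).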