LightRAG/reproduce/Step_1_openai_compatible.py
yangdx 9923821d75 refactor: Remove deprecated max_token_size from embedding configuration
This parameter is no longer used. Its removal simplifies the API and clarifies that token length management is handled by upstream text chunking logic rather than the embedding wrapper.
2025-07-29 10:49:35 +08:00

87 lines
2.1 KiB
Python

import os
import json
import time
import asyncio
import numpy as np
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.kg.shared_storage import initialize_pipeline_status
## For Upstage API
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"solar-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
**kwargs,
)
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embed(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
)
## /For Upstage API
def insert_text(rag, file_path):
with open(file_path, mode="r") as f:
unique_contexts = json.load(f)
retries = 0
max_retries = 3
while retries < max_retries:
try:
rag.insert(unique_contexts)
break
except Exception as e:
retries += 1
print(f"Insertion failed, retrying ({retries}/{max_retries}), error: {e}")
time.sleep(10)
if retries == max_retries:
print("Insertion failed after exceeding the maximum number of retries")
cls = "mix"
WORKING_DIR = f"../{cls}"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
async def initialize_rag():
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(embedding_dim=4096, func=embedding_func),
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
def main():
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
if __name__ == "__main__":
main()