LightRAG/lightrag/api/ollama_lightrag_server.py

473 lines
16 KiB
Python
Raw Normal View History

2024-12-16 01:05:49 +01:00
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from pydantic import BaseModel
import logging
import argparse
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embed
2024-12-16 01:05:49 +01:00
from lightrag.utils import EmbeddingFunc
from typing import Optional, List
from enum import Enum
2024-12-17 23:36:30 +01:00
from pathlib import Path
import shutil
2024-12-17 23:51:49 +01:00
import aiofiles
2024-12-19 11:44:01 +01:00
from ascii_colors import trace_exception
from fastapi import FastAPI, HTTPException
import os
from typing import Optional
from fastapi import FastAPI, Depends, HTTPException, Security
from fastapi.security import APIKeyHeader
import os
import argparse
from typing import Optional
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
from fastapi import HTTPException
2024-12-19 11:44:01 +01:00
2024-12-16 01:05:49 +01:00
def parse_args():
    """Parse the server's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace exposing ``host``, ``port``, ``working_dir``,
        ``input_dir``, ``model``, ``embedding_model``, ``ollama_host``,
        ``max_async``, ``max_tokens``, ``embedding_dim``,
        ``max_embed_tokens``, ``log_level`` and ``key``.
    """
    cli = argparse.ArgumentParser(
        description="LightRAG FastAPI Server with separate working and input directories"
    )

    # (flag, add_argument keyword arguments). Declaration order is preserved
    # because it is the order in which options appear in ``--help``.
    option_table = [
        # Server configuration
        ("--host", {"default": "0.0.0.0", "help": "Server host (default: 0.0.0.0)"}),
        ("--port", {"type": int, "default": 9621, "help": "Server port (default: 9621)"}),
        # Directory configuration
        (
            "--working-dir",
            {
                "default": "./rag_storage",
                "help": "Working directory for RAG storage (default: ./rag_storage)",
            },
        ),
        (
            "--input-dir",
            {
                "default": "./inputs",
                "help": "Directory containing input documents (default: ./inputs)",
            },
        ),
        # Model configuration
        (
            "--model",
            {
                "default": "mistral-nemo:latest",
                "help": "LLM model name (default: mistral-nemo:latest)",
            },
        ),
        (
            "--embedding-model",
            {
                "default": "bge-m3:latest",
                "help": "Embedding model name (default: bge-m3:latest)",
            },
        ),
        (
            "--ollama-host",
            {
                "default": "http://localhost:11434",
                "help": "Ollama host URL (default: http://localhost:11434)",
            },
        ),
        # RAG configuration
        (
            "--max-async",
            {"type": int, "default": 4, "help": "Maximum async operations (default: 4)"},
        ),
        (
            "--max-tokens",
            {"type": int, "default": 32768, "help": "Maximum token size (default: 32768)"},
        ),
        (
            "--embedding-dim",
            {"type": int, "default": 1024, "help": "Embedding dimensions (default: 1024)"},
        ),
        (
            "--max-embed-tokens",
            {
                "type": int,
                "default": 8192,
                "help": "Maximum embedding token size (default: 8192)",
            },
        ),
        # Logging configuration
        (
            "--log-level",
            {
                "default": "INFO",
                "choices": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
                "help": "Logging level (default: INFO)",
            },
        ),
        # Authentication
        (
            "--key",
            {
                "type": str,
                "help": "API key for authentication. This protects lightrag server against unauthorized access",
                "default": None,
            },
        ),
    ]
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()
2024-12-19 11:44:01 +01:00
2024-12-17 23:36:30 +01:00
class DocumentManager:
    """Track which documents in an input directory have been indexed.

    Attributes:
        input_dir: Directory scanned for documents; created on construction
            if it does not exist.
        supported_extensions: Filename suffixes (including the dot) that are
            accepted for indexing.
        indexed_files: Set of paths that have been marked as indexed.
    """

    def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")):
        """Remember the directory and extensions and ensure the directory exists.

        Args:
            input_dir: Path of the directory holding input documents.
            supported_extensions: Suffixes treated as indexable documents.
        """
        self.input_dir = Path(input_dir)
        self.supported_extensions = supported_extensions
        self.indexed_files = set()

        # Create input directory if it doesn't exist
        self.input_dir.mkdir(parents=True, exist_ok=True)

    def scan_directory(self) -> List[Path]:
        """Return supported files under ``input_dir`` not yet marked as indexed.

        The directory is walked recursively. Only regular files are returned
        (a directory named ``notes.txt`` is skipped), and the extension test
        is the same case-insensitive check as ``is_supported_file`` so that
        scanning and uploading agree on what counts as a document.

        Returns:
            Sorted list of paths that are supported and not in
            ``indexed_files``.
        """
        return sorted(
            candidate
            for candidate in self.input_dir.rglob("*")
            if candidate.is_file()
            and self.is_supported_file(candidate.name)
            and candidate not in self.indexed_files
        )

    def mark_as_indexed(self, file_path: Path):
        """Record ``file_path`` so later scans no longer report it."""
        self.indexed_files.add(file_path)

    def is_supported_file(self, filename: str) -> bool:
        """Return True when ``filename`` ends with a supported extension.

        The comparison ignores case on both the filename and the configured
        extensions.
        """
        suffixes = tuple(ext.lower() for ext in self.supported_extensions)
        return filename.lower().endswith(suffixes)
2024-12-19 11:44:01 +01:00
2024-12-16 01:05:49 +01:00
# Pydantic models
# Retrieval modes a client may select for a query. Members are also ``str``
# instances whose value is the literal text used on the wire.
class SearchMode(str, Enum):
    naive = "naive"
    local = "local"
    # Trailing underscore because ``global`` is a reserved word in Python;
    # the serialized value is still "global".
    global_ = "global"
    hybrid = "hybrid"
2024-12-19 11:44:01 +01:00
2024-12-16 01:05:49 +01:00
# Request body accepted by the query endpoints.
class QueryRequest(BaseModel):
    # Text of the question to run against the RAG index.
    query: str
    # Retrieval mode; hybrid is used when the client does not specify one.
    mode: SearchMode = SearchMode.hybrid
    # Forwarded to QueryParam.stream by the /query endpoint.
    stream: bool = False
    # Forwarded to QueryParam.only_need_context by the query endpoints.
    only_need_context: bool = False
2024-12-16 01:05:49 +01:00
2024-12-19 11:44:01 +01:00
2024-12-16 01:05:49 +01:00
# Response body returned by the /query endpoint.
class QueryResponse(BaseModel):
    # Full answer text (streamed chunks are concatenated before returning).
    response: str
2024-12-19 11:44:01 +01:00
2024-12-16 01:05:49 +01:00
# Request body accepted by the /documents/text endpoint.
class InsertTextRequest(BaseModel):
    # Raw text to insert into the RAG index.
    text: str
    # Optional free-form description; accepted but not read by the
    # /documents/text handler.
    description: Optional[str] = None
2024-12-19 11:44:01 +01:00
2024-12-16 01:05:49 +01:00
# Response body shared by the document insertion and deletion endpoints.
class InsertResponse(BaseModel):
    # Outcome label, e.g. "success" or "partial_success".
    status: str
    # Human-readable summary of what happened.
    message: str
    # Number of documents the operation reports on.
    document_count: int
def get_api_key_dependency(api_key: Optional[str]):
    """Build the FastAPI dependency that guards endpoints with an API key.

    Args:
        api_key: The key clients must send in the ``X-API-Key`` header, or
            ``None``/empty to disable authentication.

    Returns:
        An async callable suitable for ``Depends``. Without a configured key
        it always returns ``None``. With a key it returns the supplied header
        value, or raises ``HTTPException`` with status 403 when the header is
        missing ("API Key required") or does not match ("Invalid API Key").
    """
    if not api_key:
        # If no API key is configured, return a dummy dependency that always succeeds
        async def no_auth():
            return None

        return no_auth

    # Local import keeps this block self-contained; only needed when a key
    # is configured.
    import secrets

    # auto_error=False lets us produce our own error messages below instead
    # of the scheme's default response.
    api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
    # compare_digest on str only accepts ASCII, so compare encoded bytes.
    expected_key = api_key.encode("utf-8")

    async def api_key_auth(
        api_key_header_value: Optional[str] = Security(api_key_header),
    ):
        if not api_key_header_value:
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN,
                detail="API Key required",
            )
        # Constant-time comparison so response timing does not reveal how
        # much of a guessed key is correct.
        supplied_key = api_key_header_value.encode("utf-8")
        if not secrets.compare_digest(supplied_key, expected_key):
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN,
                detail="Invalid API Key",
            )
        return api_key_header_value

    return api_key_auth
2024-12-19 11:44:01 +01:00
2024-12-16 01:05:49 +01:00
def create_app(args):
    """Build the FastAPI application that exposes LightRAG over HTTP.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``log_level``, ``key``, ``working_dir``, ``input_dir``, ``model``,
            ``max_async``, ``max_tokens``, ``ollama_host``, ``embedding_dim``,
            ``max_embed_tokens`` and ``embedding_model``.

    Returns:
        The configured ``FastAPI`` instance with document, query and health
        routes registered. When an API key is configured (environment
        variable ``LIGHTRAG_API_KEY`` or ``args.key``) every route requires it.
    """
    # Local import: only the streaming endpoint needs it, and keeping it here
    # makes this function self-contained.
    from fastapi.responses import StreamingResponse

    # Setup logging
    logging.basicConfig(
        format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
    )

    # Check if API key is provided either through env var or args
    # (the environment variable wins when both are set).
    api_key = os.getenv("LIGHTRAG_API_KEY") or args.key

    # Build the description explicitly: a conditional expression binds looser
    # than ``+``, so an inline ``a + b if cond else ""`` yields "" without a key.
    description = (
        "API for querying text using LightRAG with separate storage and input directories"
    )
    if api_key:
        description += " (With authentication)"

    # Initialize FastAPI
    app = FastAPI(
        title="LightRAG API",
        description=description,
        version="1.0.0",
        openapi_tags=[{"name": "api"}],
    )

    # Add CORS middleware
    # NOTE(review): wildcard origins together with allow_credentials=True is
    # very permissive - confirm this is intended for the deployment.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Create the optional API key dependency
    optional_api_key = get_api_key_dependency(api_key)

    # Create working directory if it doesn't exist
    Path(args.working_dir).mkdir(parents=True, exist_ok=True)

    # Initialize document manager
    doc_manager = DocumentManager(args.input_dir)

    # Initialize RAG
    rag = LightRAG(
        working_dir=args.working_dir,
        llm_model_func=ollama_model_complete,
        llm_model_name=args.model,
        llm_model_max_async=args.max_async,
        llm_model_max_token_size=args.max_tokens,
        llm_model_kwargs={
            "host": args.ollama_host,
            "options": {"num_ctx": args.max_tokens},
        },
        embedding_func=EmbeddingFunc(
            embedding_dim=args.embedding_dim,
            max_token_size=args.max_embed_tokens,
            func=lambda texts: ollama_embed(
                texts, embed_model=args.embedding_model, host=args.ollama_host
            ),
        ),
    )

    async def _index_file(file_path):
        """Read ``file_path`` as UTF-8, insert it into the RAG and mark it indexed."""
        # Async file reading so the event loop is not blocked on disk I/O.
        async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
            content = await f.read()
        await rag.ainsert(content)
        doc_manager.mark_as_indexed(file_path)

    @app.on_event("startup")
    async def startup_event():
        """Index all files in input directory during startup"""
        try:
            new_files = doc_manager.scan_directory()
            indexed_count = 0
            for file_path in new_files:
                try:
                    await _index_file(file_path)
                    indexed_count += 1
                    logging.info(f"Indexed file: {file_path}")
                except Exception as e:
                    # Best effort: one bad file must not stop the server
                    # from starting or the remaining files from indexing.
                    trace_exception(e)
                    logging.error(f"Error indexing file {file_path}: {str(e)}")

            # Report files that were actually indexed, not merely attempted.
            logging.info(f"Indexed {indexed_count} documents from {args.input_dir}")

        except Exception as e:
            logging.error(f"Error during startup indexing: {str(e)}")

    @app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
    async def scan_for_new_documents():
        """Manually trigger scanning for new documents"""
        try:
            new_files = doc_manager.scan_directory()
            indexed_count = 0

            for file_path in new_files:
                try:
                    await _index_file(file_path)
                    indexed_count += 1
                except Exception as e:
                    logging.error(f"Error indexing file {file_path}: {str(e)}")

            return {
                "status": "success",
                "indexed_count": indexed_count,
                "total_documents": len(doc_manager.indexed_files),
            }
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

    @app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
    async def upload_to_input_dir(file: UploadFile = File(...)):
        """Upload a file to the input directory"""
        try:
            # Keep only the final path component so a crafted filename such
            # as "../../x.txt" cannot be written outside the input directory.
            safe_name = Path(file.filename or "").name
            if not safe_name or not doc_manager.is_supported_file(safe_name):
                raise HTTPException(
                    status_code=400,
                    detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
                )

            file_path = doc_manager.input_dir / safe_name
            with open(file_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

            # Immediately index the uploaded file
            await _index_file(file_path)

            return {
                "status": "success",
                "message": f"File uploaded and indexed: {safe_name}",
                "total_documents": len(doc_manager.indexed_files),
            }
        except HTTPException:
            # Preserve deliberate client errors (e.g. the 400 above) instead
            # of converting them into a 500.
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

    @app.post(
        "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
    )
    async def query_text(request: QueryRequest):
        """Run a query and return the complete answer in one response."""
        try:
            response = await rag.aquery(
                request.query,
                param=QueryParam(
                    mode=request.mode,
                    stream=request.stream,
                    only_need_context=request.only_need_context,
                ),
            )

            # With stream=True the backend is expected to return an async
            # iterator of text chunks; tolerate a plain string as well.
            if request.stream and not isinstance(response, str):
                chunks = [chunk async for chunk in response]
                return QueryResponse(response="".join(chunks))
            return QueryResponse(response=response)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

    @app.post("/query/stream", dependencies=[Depends(optional_api_key)])
    async def query_text_stream(request: QueryRequest):
        """Run a query and stream the answer chunks as plain text."""
        try:
            # Use the async API, as /query does; this handler runs inside the
            # event loop.
            response = await rag.aquery(
                request.query,
                param=QueryParam(
                    mode=request.mode,
                    stream=True,
                    only_need_context=request.only_need_context,
                ),
            )

            async def stream_generator():
                # A plain string is sent as a single chunk.
                if isinstance(response, str):
                    yield response
                    return
                async for chunk in response:
                    yield chunk

            # A bare async generator is not a valid response body; wrap it
            # so the chunks are actually streamed to the client.
            return StreamingResponse(stream_generator(), media_type="text/plain")
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

    @app.post(
        "/documents/text",
        response_model=InsertResponse,
        dependencies=[Depends(optional_api_key)],
    )
    async def insert_text(request: InsertTextRequest):
        """Insert a block of raw text into the RAG index."""
        try:
            await rag.ainsert(request.text)
            return InsertResponse(
                status="success",
                message="Text successfully inserted",
                # One document per request, matching /documents/file.
                # NOTE(review): previously len(rag); LightRAG is not known to
                # define __len__ - confirm.
                document_count=1,
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

    @app.post(
        "/documents/file",
        response_model=InsertResponse,
        dependencies=[Depends(optional_api_key)],
    )
    async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
        """Insert the contents of one uploaded text file into the RAG index."""
        try:
            # Reject unsupported types before reading the upload; the check
            # is the same case-insensitive one used by the other endpoints.
            if not doc_manager.is_supported_file(file.filename or ""):
                raise HTTPException(
                    status_code=400,
                    detail="Unsupported file type. Only .txt and .md files are supported",
                )

            content = await file.read()
            text = content.decode("utf-8")
            await rag.ainsert(text)

            return InsertResponse(
                status="success",
                message=f"File '{file.filename}' successfully inserted",
                document_count=1,
            )
        except HTTPException:
            # Keep the 400 above from being rewritten as a 500 below.
            raise
        except UnicodeDecodeError:
            raise HTTPException(status_code=400, detail="File encoding not supported")
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

    @app.post(
        "/documents/batch",
        response_model=InsertResponse,
        dependencies=[Depends(optional_api_key)],
    )
    async def insert_batch(files: List[UploadFile] = File(...)):
        """Insert several uploaded text files, reporting per-file failures."""
        try:
            inserted_count = 0
            failed_files = []

            for file in files:
                display_name = file.filename or "<unnamed>"
                try:
                    if not doc_manager.is_supported_file(file.filename or ""):
                        failed_files.append(f"{display_name} (unsupported type)")
                        continue
                    content = await file.read()
                    text = content.decode("utf-8")
                    await rag.ainsert(text)
                    inserted_count += 1
                except Exception as e:
                    # Record the failure and keep processing the other files.
                    failed_files.append(f"{display_name} ({str(e)})")

            status_message = f"Successfully inserted {inserted_count} documents"
            if failed_files:
                status_message += f". Failed files: {', '.join(failed_files)}"

            # "success" only when nothing failed; "partial_success" when some
            # files were inserted and some failed; "failure" when none were.
            if not failed_files:
                status = "success"
            elif inserted_count > 0:
                status = "partial_success"
            else:
                status = "failure"

            return InsertResponse(
                status=status,
                message=status_message,
                # Count what was actually inserted, not what was submitted.
                document_count=inserted_count,
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

    @app.delete(
        "/documents",
        response_model=InsertResponse,
        dependencies=[Depends(optional_api_key)],
    )
    async def clear_documents():
        """Reset the RAG's in-memory document stores."""
        try:
            # NOTE(review): these attributes are overwritten with an empty
            # list / None rather than cleared through a storage API, and
            # doc_manager.indexed_files is left untouched. Confirm that later
            # inserts/queries still work after this call.
            rag.text_chunks = []
            rag.entities_vdb = None
            rag.relationships_vdb = None
            return InsertResponse(
                status="success",
                message="All documents cleared successfully",
                document_count=0,
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

    @app.get("/health", dependencies=[Depends(optional_api_key)])
    async def get_status():
        """Get current system status"""
        return {
            "status": "healthy",
            "working_directory": str(args.working_dir),
            "input_directory": str(args.input_dir),
            "indexed_files": len(doc_manager.indexed_files),
            "configuration": {
                "model": args.model,
                "embedding_model": args.embedding_model,
                "max_tokens": args.max_tokens,
                "ollama_host": args.ollama_host,
            },
        }

    return app
2024-12-19 11:44:01 +01:00
def main():
    """Parse the command line, build the application and serve it with uvicorn."""
    cli_args = parse_args()

    # uvicorn is imported lazily, after argument parsing.
    import uvicorn

    uvicorn.run(create_app(cli_args), host=cli_args.host, port=cli_args.port)
2024-12-24 10:35:00 +01:00
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()