# LightRAG/api/ollama_lightrag_server.py
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import os
import logging
import argparse
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embedding
from lightrag.utils import EmbeddingFunc
from typing import Optional, List
from enum import Enum

def parse_args():
    parser = argparse.ArgumentParser(
        description="""
LightRAG FastAPI Server
=======================

A REST API server for text querying using LightRAG. Supports multiple search
modes, streaming responses, and document management.

Features:
- Multiple search modes (naive, local, global, hybrid)
- Streaming and non-streaming responses
- Document insertion and management
- Configurable model parameters
- REST API with automatic documentation
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Server configuration
    parser.add_argument('--host', default='0.0.0.0', help='Server host (default: 0.0.0.0)')
    parser.add_argument('--port', type=int, default=8000, help='Server port (default: 8000)')

    # Model configuration
    parser.add_argument('--model', default='gemma2:2b', help='LLM model name (default: gemma2:2b)')
    parser.add_argument('--embedding-model', default='nomic-embed-text',
                        help='Embedding model name (default: nomic-embed-text)')
    parser.add_argument('--ollama-host', default='http://localhost:11434',
                        help='Ollama host URL (default: http://localhost:11434)')

    # RAG configuration
    parser.add_argument('--working-dir', default='./dickens', help='Working directory for RAG (default: ./dickens)')
    parser.add_argument('--max-async', type=int, default=4, help='Maximum async operations (default: 4)')
    parser.add_argument('--max-tokens', type=int, default=32768, help='Maximum token size (default: 32768)')
    parser.add_argument('--embedding-dim', type=int, default=768, help='Embedding dimensions (default: 768)')
    parser.add_argument('--max-embed-tokens', type=int, default=8192,
                        help='Maximum embedding token size (default: 8192)')

    # Input configuration
    parser.add_argument('--input-file', default='./book.txt', help='Initial input file to process (default: ./book.txt)')

    # Logging configuration
    parser.add_argument('--log-level', default='INFO', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        help='Logging level (default: INFO)')

    return parser.parse_args()
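
# Example invocation (all flags are optional; defaults are listed in the help text):
#
#   python ollama_lightrag_server.py --model gemma2:2b --port 8000 \
#       --working-dir ./dickens --input-file ./book.txt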

# Pydantic models
class SearchMode(str, Enum):
    naive = "naive"
    local = "local"
    global_ = "global"  # trailing underscore because "global" is a Python keyword
    hybrid = "hybrid"


class QueryRequest(BaseModel):
    query: str
    mode: SearchMode = SearchMode.hybrid
    stream: bool = False


class QueryResponse(BaseModel):
    response: str


class InsertTextRequest(BaseModel):
    text: str
    description: Optional[str] = None


class InsertResponse(BaseModel):
    status: str
    message: str
    document_count: int
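
# Illustrative JSON payloads for the models above (values are examples only):
#
#   QueryRequest   -> {"query": "Who is Scrooge?", "mode": "hybrid", "stream": false}
#   QueryResponse  -> {"response": "Ebenezer Scrooge is ..."}
#   InsertResponse -> {"status": "success", "message": "...", "document_count": 1}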

def create_app(args):
    # Setup logging
    logging.basicConfig(format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level))

    # Initialize FastAPI app
    app = FastAPI(
        title="LightRAG API",
        description="""
API for querying text using LightRAG.

Configuration:
- Model: {model}
- Embedding Model: {embed_model}
- Working Directory: {work_dir}
- Max Tokens: {max_tokens}
""".format(
            model=args.model,
            embed_model=args.embedding_model,
            work_dir=args.working_dir,
            max_tokens=args.max_tokens,
        ),
    )

    # Create working directory if it doesn't exist
    if not os.path.exists(args.working_dir):
        os.makedirs(args.working_dir)

    # LightRAG does not expose a document count, so keep a simple per-process
    # counter for the InsertResponse payloads (it resets when the server restarts).
    doc_count = 0

    # Initialize RAG
    rag = LightRAG(
        working_dir=args.working_dir,
        llm_model_func=ollama_model_complete,
        llm_model_name=args.model,
        llm_model_max_async=args.max_async,
        llm_model_max_token_size=args.max_tokens,
        # num_ctx tells Ollama how large a context window to allocate
        llm_model_kwargs={"host": args.ollama_host, "options": {"num_ctx": args.max_tokens}},
        embedding_func=EmbeddingFunc(
            embedding_dim=args.embedding_dim,
            max_token_size=args.max_embed_tokens,
            func=lambda texts: ollama_embedding(
                texts, embed_model=args.embedding_model, host=args.ollama_host
            ),
        ),
    )

    @app.on_event("startup")
    async def startup_event():
        nonlocal doc_count
        try:
            with open(args.input_file, "r", encoding="utf-8") as f:
                # Use the awaitable ainsert(); the synchronous insert() drives its
                # own event loop and cannot be called from the running server loop.
                await rag.ainsert(f.read())
                doc_count += 1
        except FileNotFoundError:
            logging.warning(f"Input file {args.input_file} not found. Please ensure the file exists before querying.")
@app.post("/query", response_model=QueryResponse)
async def query_text(request: QueryRequest):
try:
response = rag.query(
request.query,
param=QueryParam(mode=request.mode, stream=request.stream)
)
if request.stream:
result = ""
async for chunk in response:
result += chunk
return QueryResponse(response=result)
else:
return QueryResponse(response=response)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
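
    # Example client call for /query using the `requests` library (illustrative
    # host, port, and query):
    #
    #   import requests
    #   resp = requests.post("http://localhost:8000/query",
    #                        json={"query": "Who is Scrooge?", "mode": "hybrid"})
    #   print(resp.json()["response"])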
@app.post("/query/stream")
async def query_text_stream(request: QueryRequest):
try:
response = rag.query(
request.query,
param=QueryParam(mode=request.mode, stream=True)
)
async def stream_generator():
async for chunk in response:
yield chunk
return stream_generator()
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
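
    # Streaming consumption sketch with `requests` (illustrative values; assumes
    # the server is reachable at localhost:8000):
    #
    #   import requests
    #   with requests.post("http://localhost:8000/query/stream",
    #                      json={"query": "Summarize chapter 1"}, stream=True) as resp:
    #       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    #           print(chunk, end="", flush=True)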
@app.post("/documents/text", response_model=InsertResponse)
async def insert_text(request: InsertTextRequest):
try:
rag.insert(request.text)
return InsertResponse(
status="success",
message="Text successfully inserted",
document_count=len(rag)
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
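
    # Example text insertion (illustrative payload):
    #
    #   import requests
    #   requests.post("http://localhost:8000/documents/text",
    #                 json={"text": "Marley was dead, to begin with.",
    #                       "description": "opening line"})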
@app.post("/documents/file", response_model=InsertResponse)
async def insert_file(
file: UploadFile = File(...),
description: str = Form(None)
):
try:
content = await file.read()
if file.filename.endswith(('.txt', '.md')):
text = content.decode('utf-8')
rag.insert(text)
else:
raise HTTPException(
status_code=400,
detail="Unsupported file type. Only .txt and .md files are supported"
)
return InsertResponse(
status="success",
message=f"File '{file.filename}' successfully inserted",
document_count=len(rag)
)
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File encoding not supported")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
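
    # Example file upload (illustrative path and filename):
    #
    #   import requests
    #   with open("notes.txt", "rb") as fh:
    #       requests.post("http://localhost:8000/documents/file",
    #                     files={"file": ("notes.txt", fh, "text/plain")})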
@app.post("/documents/batch", response_model=InsertResponse)
async def insert_batch(files: List[UploadFile] = File(...)):
try:
inserted_count = 0
failed_files = []
for file in files:
try:
content = await file.read()
if file.filename.endswith(('.txt', '.md')):
text = content.decode('utf-8')
rag.insert(text)
inserted_count += 1
else:
failed_files.append(f"{file.filename} (unsupported type)")
except Exception as e:
failed_files.append(f"{file.filename} ({str(e)})")
status_message = f"Successfully inserted {inserted_count} documents"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
return InsertResponse(
status="success" if inserted_count > 0 else "partial_success",
message=status_message,
document_count=len(rag)
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
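
    # Example batch upload (illustrative filenames):
    #
    #   import requests
    #   files = [("files", open(name, "rb")) for name in ("a.txt", "b.md")]
    #   requests.post("http://localhost:8000/documents/batch", files=files)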

    @app.delete("/documents", response_model=InsertResponse)
    async def clear_documents():
        nonlocal doc_count
        try:
            # Drop the in-memory references to the indexed data; note that this
            # does not remove anything already persisted in the working directory.
            rag.text_chunks = []
            rag.entities_vdb = None
            rag.relationships_vdb = None
            doc_count = 0
            return InsertResponse(
                status="success",
                message="All documents cleared successfully",
                document_count=0,
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
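
    # Example: clear all indexed documents (illustrative host/port):
    #
    #   import requests
    #   requests.delete("http://localhost:8000/documents")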
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"configuration": {
"model": args.model,
"embedding_model": args.embedding_model,
"working_dir": args.working_dir,
"max_tokens": args.max_tokens,
"ollama_host": args.ollama_host
}
}
return app

if __name__ == "__main__":
    args = parse_args()
    import uvicorn

    app = create_app(args)
    uvicorn.run(app, host=args.host, port=args.port)
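
# Quick end-to-end check once the server is running (hypothetical host/port):
#
#   import requests
#   print(requests.get("http://localhost:8000/health").json())
#
# FastAPI also serves interactive API docs at http://localhost:8000/docs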