mirror of
https://github.com/HKUDS/LightRAG.git
synced 2025-07-05 16:10:36 +00:00

- Add a role-based token system with metadata support
- Implement automatic guest mode when authentication is not configured
- Create a new /auth-status endpoint for checking authentication status
- Modify the frontend to auto-detect the authentication status and bypass login when appropriate
- Add a guest-mode indicator to the site header for better UX

This change lets users access the system automatically, without a manual login, when authentication is not configured, while keeping authentication secure when credentials are properly set up.
608 lines
22 KiB
Python
608 lines
22 KiB
Python
"""
|
|
LightRAG FastAPI Server
|
|
"""
|
|
|
|
from fastapi import FastAPI, Depends, HTTPException, status
|
|
import asyncio
|
|
import os
|
|
import logging
|
|
import logging.config
|
|
import uvicorn
|
|
import pipmaster as pm
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pathlib import Path
|
|
import configparser
|
|
from ascii_colors import ASCIIColors
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from contextlib import asynccontextmanager
|
|
from dotenv import load_dotenv
|
|
from lightrag.api.utils_api import (
|
|
get_api_key_dependency,
|
|
parse_args,
|
|
get_default_host,
|
|
display_splash_screen,
|
|
)
|
|
from lightrag import LightRAG
|
|
from lightrag.types import GPTKeywordExtractionFormat
|
|
from lightrag.api import __api_version__
|
|
from lightrag.utils import EmbeddingFunc
|
|
from lightrag.api.routers.document_routes import (
|
|
DocumentManager,
|
|
create_document_routes,
|
|
run_scanning_process,
|
|
)
|
|
from lightrag.api.routers.query_routes import create_query_routes
|
|
from lightrag.api.routers.graph_routes import create_graph_routes
|
|
from lightrag.api.routers.ollama_api import OllamaAPI
|
|
|
|
from lightrag.utils import logger, set_verbose_debug
|
|
from lightrag.kg.shared_storage import (
|
|
get_namespace_data,
|
|
get_pipeline_status_lock,
|
|
initialize_pipeline_status,
|
|
get_all_update_flags_status,
|
|
)
|
|
from fastapi.security import OAuth2PasswordRequestForm
|
|
from .auth import auth_handler
|
|
|
|
# Pull environment variables from the ".env" file in the current working
# directory, letting its values override ones already present in the process.
# The relative path means each LightRAG working folder can carry its own .env.
load_dotenv(dotenv_path=".env", override=True)

# Parse "config.ini" from the current working directory.
# ConfigParser.read silently skips files that do not exist.
config = configparser.ConfigParser()
config.read("config.ini")
|
|
|
|
|
|
def create_app(args):
    """Build and configure the LightRAG FastAPI application.

    Validates the LLM/embedding binding selection and SSL settings, creates the
    ``LightRAG`` engine with the selected backends, registers the document,
    query, graph and Ollama-compatible routers plus the ``/auth-status``,
    ``/login`` and ``/health`` endpoints, and mounts the web UI at ``/webui``.

    Args:
        args: Parsed server configuration (as produced by ``parse_args``).
            ``args.llm_binding_host`` / ``args.embedding_binding_host`` are
            filled in with binding defaults when they are ``None``.

    Returns:
        The configured ``FastAPI`` application.

    Raises:
        Exception: If an unsupported LLM or embedding binding is requested, or
            SSL is enabled without existing certificate and key files.
    """
    # Local import: used by /login for constant-time credential comparison.
    import secrets

    # Setup logging
    logger.setLevel(args.log_level)
    set_verbose_debug(args.verbose)

    # Verify that bindings are correctly setup
    if args.llm_binding not in [
        "lollms",
        "ollama",
        "openai",
        "openai-ollama",
        "azure_openai",
    ]:
        raise Exception("llm binding not supported")

    if args.embedding_binding not in ["lollms", "ollama", "openai", "azure_openai"]:
        raise Exception("embedding binding not supported")

    # Set default hosts if not provided
    if args.llm_binding_host is None:
        args.llm_binding_host = get_default_host(args.llm_binding)

    if args.embedding_binding_host is None:
        args.embedding_binding_host = get_default_host(args.embedding_binding)

    # Add SSL validation
    if args.ssl:
        if not args.ssl_certfile or not args.ssl_keyfile:
            raise Exception(
                "SSL certificate and key files must be provided when SSL is enabled"
            )
        if not os.path.exists(args.ssl_certfile):
            raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
        if not os.path.exists(args.ssl_keyfile):
            raise Exception(f"SSL key file not found: {args.ssl_keyfile}")

    # Check if API key is provided either through env var or args
    api_key = os.getenv("LIGHTRAG_API_KEY") or args.key

    # Initialize document manager
    doc_manager = DocumentManager(args.input_dir)

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """Lifespan context manager for startup and shutdown events."""
        # Keep strong references to background tasks so they are not
        # garbage-collected while running; each removes itself when done.
        app.state.background_tasks = set()

        try:
            # Initialize database connections
            await rag.initialize_storages()

            await initialize_pipeline_status()
            pipeline_status = await get_namespace_data("pipeline_status")

            should_start_autoscan = False
            async with get_pipeline_status_lock():
                # Auto scan documents if enabled; the shared "autoscanned" flag
                # (checked and set under the lock) lets only the first process start it.
                if args.auto_scan_at_startup:
                    if not pipeline_status.get("autoscanned", False):
                        pipeline_status["autoscanned"] = True
                        should_start_autoscan = True

            # Only run auto scan when no other process started it first
            if should_start_autoscan:
                task = asyncio.create_task(run_scanning_process(rag, doc_manager))
                app.state.background_tasks.add(task)
                task.add_done_callback(app.state.background_tasks.discard)
                logger.info(f"Process {os.getpid()} auto scan task started at startup.")

            ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")

            yield

        finally:
            # Clean up database connections
            await rag.finalize_storages()

    # Build the description explicitly: a bare `"..." + "..." if api_key else ""`
    # parses as `("..." + "...") if api_key else ""` and blanks the description.
    description = (
        "API for querying text using LightRAG with separate storage and input directories"
    )
    if api_key:
        description += " (With authentication)"

    # Initialize FastAPI
    app = FastAPI(
        title="LightRAG API",
        description=description,
        version=__api_version__,
        openapi_url="/openapi.json",  # Explicitly set OpenAPI schema URL
        docs_url="/docs",  # Explicitly set docs URL
        redoc_url="/redoc",  # Explicitly set redoc URL
        openapi_tags=[{"name": "api"}],
        lifespan=lifespan,
    )

    def get_cors_origins():
        """Get allowed origins from the CORS_ORIGINS environment variable.

        Returns a list of allowed origins; defaults to ["*"] if not set.
        Blank entries (e.g. from a trailing comma) are ignored.
        """
        origins_str = os.getenv("CORS_ORIGINS", "*")
        if origins_str == "*":
            return ["*"]
        return [origin.strip() for origin in origins_str.split(",") if origin.strip()]

    # Add CORS middleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=get_cors_origins(),
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Create the optional API key dependency
    optional_api_key = get_api_key_dependency(api_key)

    # Create working directory if it doesn't exist
    Path(args.working_dir).mkdir(parents=True, exist_ok=True)

    # Import only the backends that are actually selected.
    if args.llm_binding == "lollms" or args.embedding_binding == "lollms":
        from lightrag.llm.lollms import lollms_model_complete, lollms_embed
    if args.llm_binding == "ollama" or args.embedding_binding == "ollama":
        from lightrag.llm.ollama import ollama_model_complete, ollama_embed
    if args.llm_binding == "openai" or args.embedding_binding == "openai":
        from lightrag.llm.openai import openai_complete_if_cache, openai_embed
    if args.llm_binding == "azure_openai" or args.embedding_binding == "azure_openai":
        from lightrag.llm.azure_openai import (
            azure_openai_complete_if_cache,
            azure_openai_embed,
        )
    # "openai-ollama" is an LLM binding name (not a host), so check llm_binding.
    if args.llm_binding == "openai-ollama" or args.embedding_binding == "ollama":
        from lightrag.llm.openai import openai_complete_if_cache
        from lightrag.llm.ollama import ollama_embed

    async def openai_alike_model_complete(
        prompt,
        system_prompt=None,
        history_messages=None,
        keyword_extraction=False,
        **kwargs,
    ) -> str:
        """Complete `prompt` via an OpenAI-compatible endpoint at llm_binding_host."""
        # `keyword_extraction` is a named parameter, so it can never appear in
        # kwargs; use it directly (popping with default None discarded it).
        if keyword_extraction:
            kwargs["response_format"] = GPTKeywordExtractionFormat
        if history_messages is None:
            history_messages = []
        return await openai_complete_if_cache(
            args.llm_model,
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            base_url=args.llm_binding_host,
            api_key=args.llm_binding_api_key,
            **kwargs,
        )

    async def azure_openai_model_complete(
        prompt,
        system_prompt=None,
        history_messages=None,
        keyword_extraction=False,
        **kwargs,
    ) -> str:
        """Complete `prompt` via Azure OpenAI (key/version read from the environment)."""
        # See openai_alike_model_complete: honour the named parameter directly.
        if keyword_extraction:
            kwargs["response_format"] = GPTKeywordExtractionFormat
        if history_messages is None:
            history_messages = []
        return await azure_openai_complete_if_cache(
            args.llm_model,
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            base_url=args.llm_binding_host,
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"),
            **kwargs,
        )

    def _embed(texts):
        """Dispatch an embedding request to the configured embedding binding."""
        if args.embedding_binding == "lollms":
            return lollms_embed(
                texts,
                embed_model=args.embedding_model,
                host=args.embedding_binding_host,
                api_key=args.embedding_binding_api_key,
            )
        if args.embedding_binding == "ollama":
            return ollama_embed(
                texts,
                embed_model=args.embedding_model,
                host=args.embedding_binding_host,
                api_key=args.embedding_binding_api_key,
            )
        if args.embedding_binding == "azure_openai":
            # No host is passed for Azure embeddings.
            return azure_openai_embed(
                texts,
                model=args.embedding_model,
                api_key=args.embedding_binding_api_key,
            )
        return openai_embed(
            texts,
            model=args.embedding_model,
            base_url=args.embedding_binding_host,
            api_key=args.embedding_binding_api_key,
        )

    embedding_func = EmbeddingFunc(
        embedding_dim=args.embedding_dim,
        max_token_size=args.max_embed_tokens,
        func=_embed,
    )

    # Select the completion function and its extra kwargs for the LLM binding.
    if args.llm_binding == "lollms":
        llm_model_func = lollms_model_complete
    elif args.llm_binding == "ollama":
        llm_model_func = ollama_model_complete
    elif args.llm_binding == "azure_openai":
        llm_model_func = azure_openai_model_complete
    else:
        # "openai" and "openai-ollama" both use an OpenAI-compatible endpoint
        # (previously "openai-ollama" was wrongly routed to the Azure client).
        llm_model_func = openai_alike_model_complete

    if args.llm_binding in ("lollms", "ollama"):
        llm_model_kwargs = {
            "host": args.llm_binding_host,
            "timeout": args.timeout,
            "options": {"num_ctx": args.max_tokens},
            "api_key": args.llm_binding_api_key,
        }
    elif args.llm_binding == "azure_openai":
        llm_model_kwargs = {"timeout": args.timeout}
    else:
        llm_model_kwargs = {}

    # Initialize RAG
    rag = LightRAG(
        working_dir=args.working_dir,
        llm_model_func=llm_model_func,
        llm_model_name=args.llm_model,
        llm_model_max_async=args.max_async,
        llm_model_max_token_size=args.max_tokens,
        chunk_token_size=int(args.chunk_size),
        chunk_overlap_token_size=int(args.chunk_overlap_size),
        llm_model_kwargs=llm_model_kwargs,
        embedding_func=embedding_func,
        kv_storage=args.kv_storage,
        graph_storage=args.graph_storage,
        vector_storage=args.vector_storage,
        doc_status_storage=args.doc_status_storage,
        vector_db_storage_cls_kwargs={
            "cosine_better_than_threshold": args.cosine_threshold
        },
        enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
        embedding_cache_config={
            "enabled": True,
            "similarity_threshold": 0.95,
            "use_llm_check": False,
        },
        namespace_prefix=args.namespace_prefix,
        # Storages are initialized/finalized explicitly in `lifespan`.
        auto_manage_storages_states=False,
    )

    # Add routes
    app.include_router(create_document_routes(rag, doc_manager, api_key))
    app.include_router(create_query_routes(rag, api_key, args.top_k))
    app.include_router(create_graph_routes(rag, api_key))

    # Add Ollama API routes
    ollama_api = OllamaAPI(rag, top_k=args.top_k)
    app.include_router(ollama_api.router, prefix="/api")

    def _guest_token_response():
        """Build the guest-access token payload used when AUTH_* credentials are unset."""
        guest_token = auth_handler.create_token(
            username="guest",
            role="guest",
            metadata={"auth_mode": "disabled"},
        )
        return {
            "access_token": guest_token,
            "token_type": "bearer",
            "auth_mode": "disabled",
            "message": "Authentication is disabled. Using guest access.",
        }

    @app.get("/auth-status", dependencies=[Depends(optional_api_key)])
    async def get_auth_status():
        """Get authentication status and guest token if auth is not configured."""
        username = os.getenv("AUTH_USERNAME")
        password = os.getenv("AUTH_PASSWORD")

        if not (username and password):
            # Authentication not configured, return guest token
            return {"auth_configured": False, **_guest_token_response()}

        return {"auth_configured": True, "auth_mode": "enabled"}

    @app.post("/login", dependencies=[Depends(optional_api_key)])
    async def login(form_data: OAuth2PasswordRequestForm = Depends()):
        """Exchange AUTH_USERNAME/AUTH_PASSWORD for a user token (guest token if unset).

        Raises:
            HTTPException: 401 when the submitted credentials do not match.
        """
        username = os.getenv("AUTH_USERNAME")
        password = os.getenv("AUTH_PASSWORD")

        if not (username and password):
            # Authentication not configured, return guest token
            return _guest_token_response()

        # Constant-time comparison so response timing does not reveal how much
        # of a guess matched; encode to bytes because compare_digest rejects
        # non-ASCII str arguments.
        username_ok = secrets.compare_digest(
            form_data.username.encode("utf-8"), username.encode("utf-8")
        )
        password_ok = secrets.compare_digest(
            form_data.password.encode("utf-8"), password.encode("utf-8")
        )
        if not (username_ok and password_ok):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
            )

        # Regular user login
        user_token = auth_handler.create_token(
            username=username,
            role="user",
            metadata={"auth_mode": "enabled"},
        )
        return {
            "access_token": user_token,
            "token_type": "bearer",
            "auth_mode": "enabled",
        }

    @app.get("/health", dependencies=[Depends(optional_api_key)])
    async def get_status():
        """Get current system status."""
        # Get update flags status for all namespaces
        update_status = await get_all_update_flags_status()

        return {
            "status": "healthy",
            "working_directory": str(args.working_dir),
            "input_directory": str(args.input_dir),
            "configuration": {
                # LLM configuration binding/host address (if applicable)/model (if applicable)
                "llm_binding": args.llm_binding,
                "llm_binding_host": args.llm_binding_host,
                "llm_model": args.llm_model,
                # embedding model configuration binding/host address (if applicable)/model (if applicable)
                "embedding_binding": args.embedding_binding,
                "embedding_binding_host": args.embedding_binding_host,
                "embedding_model": args.embedding_model,
                "max_tokens": args.max_tokens,
                "kv_storage": args.kv_storage,
                "doc_status_storage": args.doc_status_storage,
                "graph_storage": args.graph_storage,
                "vector_storage": args.vector_storage,
                "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
            },
            "update_status": update_status,
        }

    # Custom StaticFiles class to prevent caching of HTML files
    class NoCacheStaticFiles(StaticFiles):
        async def get_response(self, path: str, scope):
            response = await super().get_response(path, scope)
            if path.endswith(".html"):
                response.headers["Cache-Control"] = (
                    "no-cache, no-store, must-revalidate"
                )
                response.headers["Pragma"] = "no-cache"
                response.headers["Expires"] = "0"
            return response

    # Webui mount webui/index.html
    static_dir = Path(__file__).parent / "webui"
    static_dir.mkdir(exist_ok=True)
    app.mount(
        "/webui",
        NoCacheStaticFiles(directory=static_dir, html=True, check_dir=True),
        name="webui",
    )

    return app
|
|
|
|
|
|
def get_application(args=None):
    """Return a configured FastAPI app; parse the CLI arguments if none are given.

    This is the factory that Gunicorn workers call to build the application.
    """
    effective_args = parse_args() if args is None else args
    return create_app(effective_args)
|
|
|
|
|
|
def configure_logging():
    """Configure logging for uvicorn startup.

    Clears existing handlers/filters on the uvicorn and lightrag loggers, then
    installs a console handler plus a rotating file handler writing to
    ``$LOG_DIR/lightrag.log`` (``LOG_DIR`` defaults to the current directory).
    Rotation is controlled by ``LOG_MAX_BYTES`` (default 10 MiB) and
    ``LOG_BACKUP_COUNT`` (default 5).
    """
    # Reset any existing handlers to ensure clean configuration
    for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
        target_logger = logging.getLogger(logger_name)
        target_logger.handlers = []
        target_logger.filters = []

    # Get log directory path from environment variable
    log_dir = os.getenv("LOG_DIR", os.getcwd())
    log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))

    print(f"\nLightRAG log file: {log_file_path}\n")
    # Create the directory that will hold the log file itself. The previous
    # `os.path.dirname(log_dir)` created only the parent, so a missing LOG_DIR
    # made RotatingFileHandler fail, and LOG_DIR="" crashed makedirs("").
    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

    # Get log file max size and backup count from environment variables
    log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
    log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups

    logging.config.dictConfig(
        {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": {
                "default": {
                    "format": "%(levelname)s: %(message)s",
                },
                "detailed": {
                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                },
            },
            "handlers": {
                "console": {
                    "formatter": "default",
                    "class": "logging.StreamHandler",
                    "stream": "ext://sys.stderr",
                },
                "file": {
                    "formatter": "detailed",
                    "class": "logging.handlers.RotatingFileHandler",
                    "filename": log_file_path,
                    "maxBytes": log_max_bytes,
                    "backupCount": log_backup_count,
                    "encoding": "utf-8",
                },
            },
            "loggers": {
                # Configure all uvicorn related loggers
                "uvicorn": {
                    "handlers": ["console", "file"],
                    "level": "INFO",
                    "propagate": False,
                },
                "uvicorn.access": {
                    "handlers": ["console", "file"],
                    "level": "INFO",
                    "propagate": False,
                    "filters": ["path_filter"],
                },
                "uvicorn.error": {
                    "handlers": ["console", "file"],
                    "level": "INFO",
                    "propagate": False,
                },
                "lightrag": {
                    "handlers": ["console", "file"],
                    "level": "INFO",
                    "propagate": False,
                    "filters": ["path_filter"],
                },
            },
            "filters": {
                "path_filter": {
                    "()": "lightrag.utils.LightragPathFilter",
                },
            },
        }
    )
|
|
|
|
|
|
def check_and_install_dependencies():
    """Install any missing runtime packages with pipmaster.

    Every package that ``pm.is_installed`` reports as absent is installed via
    ``pm.install``, with a progress message printed before and after.
    """
    # Extend this tuple when the server gains new hard runtime requirements.
    for pkg_name in ("uvicorn", "tiktoken", "fastapi"):
        if pm.is_installed(pkg_name):
            continue
        print(f"Installing {pkg_name}...")
        pm.install(pkg_name)
        print(f"{pkg_name} installed successfully")
|
|
|
|
|
|
def main():
    """Run the LightRAG server under Uvicorn in single-process mode.

    If the GUNICORN_CMD_ARGS environment variable is present, this returns
    immediately because Gunicorn builds the app itself via ``get_application``.
    """
    if "GUNICORN_CMD_ARGS" in os.environ:
        print("Running under Gunicorn - worker management handled by Gunicorn")
        return

    check_and_install_dependencies()

    from multiprocessing import freeze_support

    freeze_support()

    # Logging is set up before the arguments are parsed.
    configure_logging()

    args = parse_args(is_uvicorn_mode=True)
    display_splash_screen(args)

    # Hand uvicorn the application object itself rather than an import string.
    application = create_app(args)

    run_options = dict(
        app=application,
        host=args.host,
        port=args.port,
        log_config=None,  # disable uvicorn's default logging config
    )
    if args.ssl:
        run_options["ssl_certfile"] = args.ssl_certfile
        run_options["ssl_keyfile"] = args.ssl_keyfile

    print(f"Starting Uvicorn server in single-process mode on {args.host}:{args.port}")
    uvicorn.run(**run_options)
|
|
|
|
|
|
# Start the server when this module is executed directly as a script.
if __name__ == "__main__":
    main()
|