LightRAG/lightrag/api/lightrag_server.py


"""
LightRAG FastAPI Server
"""
from fastapi import FastAPI, Depends, HTTPException, status
import asyncio
import os
import logging
import logging.config
import uvicorn
import pipmaster as pm
from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse
from pathlib import Path
import configparser
from ascii_colors import ASCIIColors
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from lightrag.api.utils_api import (
get_combined_auth_dependency,
parse_args,
get_default_host,
display_splash_screen,
check_env_file,
)
import sys
from lightrag import LightRAG, __version__ as core_version
from lightrag.api import __api_version__
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.utils import EmbeddingFunc
from lightrag.api.routers.document_routes import (
DocumentManager,
create_document_routes,
run_scanning_process,
)
from lightrag.api.routers.query_routes import create_query_routes
from lightrag.api.routers.graph_routes import create_graph_routes
from lightrag.api.routers.ollama_api import OllamaAPI
from lightrag.utils import logger, set_verbose_debug
from lightrag.kg.shared_storage import (
get_namespace_data,
get_pipeline_status_lock,
initialize_pipeline_status,
)
from fastapi.security import OAuth2PasswordRequestForm
from lightrag.api.auth import auth_handler
# Load environment variables
# Updated to use the .env that is inside the current folder
# This allows the user to put a different .env file in each LightRAG folder
load_dotenv(".env")
# Initialize config parser
config = configparser.ConfigParser()
config.read("config.ini")
# Global authentication configuration
auth_configured = bool(auth_handler.accounts)
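# When no accounts are configured, auth_handler.accounts is empty and the
# /auth-status and /login endpoints defined below fall back to issuing guest tokens.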
def create_app(args):
# Setup logging
logger.setLevel(args.log_level)
set_verbose_debug(args.verbose)
# Verify that bindings are correctly setup
if args.llm_binding not in [
"lollms",
"ollama",
"openai",
"openai-ollama",
"azure_openai",
]:
raise Exception("llm binding not supported")
if args.embedding_binding not in ["lollms", "ollama", "openai", "azure_openai"]:
raise Exception("embedding binding not supported")
# Set default hosts if not provided
if args.llm_binding_host is None:
args.llm_binding_host = get_default_host(args.llm_binding)
if args.embedding_binding_host is None:
args.embedding_binding_host = get_default_host(args.embedding_binding)
# Add SSL validation
if args.ssl:
if not args.ssl_certfile or not args.ssl_keyfile:
raise Exception(
"SSL certificate and key files must be provided when SSL is enabled"
)
if not os.path.exists(args.ssl_certfile):
raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
if not os.path.exists(args.ssl_keyfile):
raise Exception(f"SSL key file not found: {args.ssl_keyfile}")
# Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize document manager
doc_manager = DocumentManager(args.input_dir)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events"""
# Store background tasks
app.state.background_tasks = set()
try:
# Initialize database connections
await rag.initialize_storages()
await initialize_pipeline_status()
pipeline_status = await get_namespace_data("pipeline_status")
should_start_autoscan = False
async with get_pipeline_status_lock():
# Auto scan documents if enabled
if args.auto_scan_at_startup:
if not pipeline_status.get("autoscanned", False):
pipeline_status["autoscanned"] = True
should_start_autoscan = True
# Only run auto scan when no other process started it first
if should_start_autoscan:
# Create background task
task = asyncio.create_task(run_scanning_process(rag, doc_manager))
app.state.background_tasks.add(task)
task.add_done_callback(app.state.background_tasks.discard)
logger.info(f"Process {os.getpid()} auto scan task started at startup.")
ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")
yield
finally:
# Clean up database connections
await rag.finalize_storages()
# Initialize FastAPI
app_kwargs = {
"title": "LightRAG Server API",
"description": "Providing API for LightRAG core, Web UI and Ollama Model Emulation"
2025-01-04 02:23:39 +01:00
+ "(With authentication)"
if api_key
else "",
"version": __api_version__,
"openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL
"docs_url": "/docs", # Explicitly set docs URL
"redoc_url": "/redoc", # Explicitly set redoc URL
"openapi_tags": [{"name": "api"}],
"lifespan": lifespan,
}
# Configure Swagger UI parameters
# Enable persistAuthorization and tryItOutEnabled for better user experience
app_kwargs["swagger_ui_parameters"] = {
"persistAuthorization": True,
"tryItOutEnabled": True,
}
app = FastAPI(**app_kwargs)
def get_cors_origins():
"""Get allowed origins from environment variable
Returns a list of allowed origins, defaults to ["*"] if not set
"""
origins_str = os.getenv("CORS_ORIGINS", "*")
if origins_str == "*":
return ["*"]
return [origin.strip() for origin in origins_str.split(",")]
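# Example (illustrative): CORS_ORIGINS="http://localhost:5173,https://app.example.com"
# yields ["http://localhost:5173", "https://app.example.com"]; leaving the
# variable unset allows all origins.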
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=get_cors_origins(),
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create combined auth dependency for all endpoints
combined_auth = get_combined_auth_dependency(api_key)
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
if args.llm_binding == "lollms" or args.embedding_binding == "lollms":
from lightrag.llm.lollms import lollms_model_complete, lollms_embed
if args.llm_binding == "ollama" or args.embedding_binding == "ollama":
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
if args.llm_binding == "openai" or args.embedding_binding == "openai":
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
if args.llm_binding == "azure_openai" or args.embedding_binding == "azure_openai":
from lightrag.llm.azure_openai import (
azure_openai_complete_if_cache,
azure_openai_embed,
)
if args.llm_binding == "openai-ollama" or args.embedding_binding == "ollama":
from lightrag.llm.openai import openai_complete_if_cache
from lightrag.llm.ollama import ollama_embed
async def openai_alike_model_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
**kwargs,
) -> str:
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
if history_messages is None:
history_messages = []
kwargs["temperature"] = args.temperature
return await openai_complete_if_cache(
args.llm_model,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
base_url=args.llm_binding_host,
api_key=args.llm_binding_api_key,
**kwargs,
)
async def azure_openai_model_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
**kwargs,
) -> str:
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
if history_messages is None:
history_messages = []
kwargs["temperature"] = args.temperature
return await azure_openai_complete_if_cache(
args.llm_model,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
base_url=args.llm_binding_host,
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"),
**kwargs,
)
embedding_func = EmbeddingFunc(
embedding_dim=args.embedding_dim,
max_token_size=args.max_embed_tokens,
func=lambda texts: lollms_embed(
texts,
embed_model=args.embedding_model,
host=args.embedding_binding_host,
api_key=args.embedding_binding_api_key,
)
if args.embedding_binding == "lollms"
else ollama_embed(
texts,
embed_model=args.embedding_model,
host=args.embedding_binding_host,
api_key=args.embedding_binding_api_key,
)
if args.embedding_binding == "ollama"
else azure_openai_embed(
texts,
model=args.embedding_model,  # no host argument is used for azure_openai
api_key=args.embedding_binding_api_key,
)
if args.embedding_binding == "azure_openai"
else openai_embed(
texts,
model=args.embedding_model,
base_url=args.embedding_binding_host,
api_key=args.embedding_binding_api_key,
),
)
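# The chained conditional above dispatches on args.embedding_binding:
# "lollms" -> lollms_embed, "ollama" -> ollama_embed,
# "azure_openai" -> azure_openai_embed, and anything else falls through to
# openai_embed (the only remaining supported binding is "openai").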
# Initialize RAG
if args.llm_binding in ["lollms", "ollama", "openai"]:
rag = LightRAG(
working_dir=args.working_dir,
llm_model_func=lollms_model_complete
if args.llm_binding == "lollms"
else ollama_model_complete
if args.llm_binding == "ollama"
else openai_alike_model_complete,
llm_model_name=args.llm_model,
llm_model_max_async=args.max_async,
llm_model_max_token_size=args.max_tokens,
chunk_token_size=int(args.chunk_size),
chunk_overlap_token_size=int(args.chunk_overlap_size),
llm_model_kwargs={
"host": args.llm_binding_host,
"timeout": args.timeout,
"options": {"num_ctx": args.max_tokens},
"api_key": args.llm_binding_api_key,
}
if args.llm_binding == "lollms" or args.llm_binding == "ollama"
else {},
embedding_func=embedding_func,
kv_storage=args.kv_storage,
graph_storage=args.graph_storage,
vector_storage=args.vector_storage,
doc_status_storage=args.doc_status_storage,
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
embedding_cache_config={
"enabled": True,
"similarity_threshold": 0.95,
"use_llm_check": False,
},
namespace_prefix=args.namespace_prefix,
auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert,
)
else: # azure_openai
rag = LightRAG(
working_dir=args.working_dir,
llm_model_func=azure_openai_model_complete,
chunk_token_size=int(args.chunk_size),
chunk_overlap_token_size=int(args.chunk_overlap_size),
llm_model_kwargs={
"timeout": args.timeout,
},
llm_model_name=args.llm_model,
llm_model_max_async=args.max_async,
llm_model_max_token_size=args.max_tokens,
embedding_func=embedding_func,
kv_storage=args.kv_storage,
graph_storage=args.graph_storage,
vector_storage=args.vector_storage,
doc_status_storage=args.doc_status_storage,
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
embedding_cache_config={
"enabled": True,
"similarity_threshold": 0.95,
"use_llm_check": False,
},
namespace_prefix=args.namespace_prefix,
auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert,
)
# Add routes
app.include_router(create_document_routes(rag, doc_manager, api_key))
app.include_router(create_query_routes(rag, api_key, args.top_k))
app.include_router(create_graph_routes(rag, api_key))
# Add Ollama API routes
ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key)
app.include_router(ollama_api.router, prefix="/api")
@app.get("/")
async def redirect_to_webui():
"""Redirect root path to /webui"""
return RedirectResponse(url="/webui")
@app.get("/auth-status")
async def get_auth_status():
"""Get authentication status and guest token if auth is not configured"""
if not auth_handler.accounts:
# Authentication not configured, return guest token
guest_token = auth_handler.create_token(
username="guest", role="guest", metadata={"auth_mode": "disabled"}
)
return {
"auth_configured": False,
"access_token": guest_token,
"token_type": "bearer",
"auth_mode": "disabled",
"message": "Authentication is disabled. Using guest access.",
"core_version": core_version,
"api_version": __api_version__,
}
return {
"auth_configured": True,
"auth_mode": "enabled",
"core_version": core_version,
"api_version": __api_version__,
}
@app.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
if not auth_handler.accounts:
# Authentication not configured, return guest token
guest_token = auth_handler.create_token(
username="guest", role="guest", metadata={"auth_mode": "disabled"}
)
return {
"access_token": guest_token,
"token_type": "bearer",
"auth_mode": "disabled",
"message": "Authentication is disabled. Using guest access.",
"core_version": core_version,
"api_version": __api_version__,
}
username = form_data.username
if auth_handler.accounts.get(username) != form_data.password:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect credentials"
)
# Regular user login
user_token = auth_handler.create_token(
username=username, role="user", metadata={"auth_mode": "enabled"}
)
return {
"access_token": user_token,
"token_type": "bearer",
"auth_mode": "enabled",
"core_version": core_version,
"api_version": __api_version__,
}
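# Illustrative login request (field names come from OAuth2PasswordRequestForm;
# host, port and credentials are placeholders):
#   curl -X POST http://localhost:9621/login \
#       -d "username=admin" -d "password=secret"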
@app.get("/health", dependencies=[Depends(combined_auth)])
async def get_status():
"""Get current system status"""
try:
pipeline_status = await get_namespace_data("pipeline_status")
auth_mode = "enabled" if auth_configured else "disabled"
return {
"status": "healthy",
"working_directory": str(args.working_dir),
"input_directory": str(args.input_dir),
"configuration": {
# LLM configuration: binding, host address (if applicable) and model (if applicable)
"llm_binding": args.llm_binding,
"llm_binding_host": args.llm_binding_host,
"llm_model": args.llm_model,
# Embedding configuration: binding, host address (if applicable) and model (if applicable)
"embedding_binding": args.embedding_binding,
"embedding_binding_host": args.embedding_binding_host,
"embedding_model": args.embedding_model,
"max_tokens": args.max_tokens,
"kv_storage": args.kv_storage,
"doc_status_storage": args.doc_status_storage,
"graph_storage": args.graph_storage,
"vector_storage": args.vector_storage,
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
},
"core_version": core_version,
"api_version": __api_version__,
"auth_mode": auth_mode,
"pipeline_busy": pipeline_status.get("busy", False),
}
except Exception as e:
logger.error(f"Error getting health status: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
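# Illustrative health probe (host/port are placeholders; when auth is enabled,
# include credentials accepted by combined_auth, e.g. a bearer token):
#   curl http://localhost:9621/health -H "Authorization: Bearer <token>"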
# Custom StaticFiles class to prevent caching of HTML files
class NoCacheStaticFiles(StaticFiles):
async def get_response(self, path: str, scope):
response = await super().get_response(path, scope)
if path.endswith(".html"):
response.headers["Cache-Control"] = (
"no-cache, no-store, must-revalidate"
)
response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0"
return response
# Mount the Web UI (serves webui/index.html)
static_dir = Path(__file__).parent / "webui"
static_dir.mkdir(exist_ok=True)
app.mount(
"/webui",
NoCacheStaticFiles(directory=static_dir, html=True, check_dir=True),
name="webui",
)
return app
def get_application(args=None):
"""Factory function for creating the FastAPI application"""
if args is None:
args = parse_args()
return create_app(args)
def configure_logging():
"""Configure logging for uvicorn startup"""
# Reset any existing handlers to ensure clean configuration
for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
logger = logging.getLogger(logger_name)
logger.handlers = []
logger.filters = []
# Get log directory path from environment variable
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
print(f"\nLightRAG log file: {log_file_path}\n")
os.makedirs(os.path.dirname(log_file_path), exist_ok=True)  # ensure the log directory exists
# Get log file max size and backup count from environment variables
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
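# Illustrative overrides (values are examples only):
#   LOG_DIR=/var/log/lightrag LOG_MAX_BYTES=20971520 LOG_BACKUP_COUNT=10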
logging.config.dictConfig(
{
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "%(levelname)s: %(message)s",
},
"detailed": {
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
},
},
"handlers": {
"console": {
"formatter": "default",
"class": "logging.StreamHandler",
"stream": "ext://sys.stderr",
},
"file": {
"formatter": "detailed",
"class": "logging.handlers.RotatingFileHandler",
"filename": log_file_path,
"maxBytes": log_max_bytes,
"backupCount": log_backup_count,
"encoding": "utf-8",
},
},
"loggers": {
# Configure all uvicorn-related loggers
"uvicorn": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
},
"uvicorn.access": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
"filters": ["path_filter"],
},
"uvicorn.error": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
},
"lightrag": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
"filters": ["path_filter"],
},
},
"filters": {
"path_filter": {
"()": "lightrag.utils.LightragPathFilter",
},
},
}
)
def check_and_install_dependencies():
"""Check and install required dependencies"""
required_packages = [
"uvicorn",
"tiktoken",
"fastapi",
# Add other required packages here
]
for package in required_packages:
if not pm.is_installed(package):
print(f"Installing {package}...")
pm.install(package)
print(f"{package} installed successfully")
def main():
# Check if running under Gunicorn
if "GUNICORN_CMD_ARGS" in os.environ:
# If started with Gunicorn, return directly as Gunicorn will call get_application
print("Running under Gunicorn - worker management handled by Gunicorn")
return
# Check .env file
if not check_env_file():
sys.exit(1)
# Check and install dependencies
check_and_install_dependencies()
from multiprocessing import freeze_support
freeze_support()
# Configure logging before parsing args
configure_logging()
args = parse_args(is_uvicorn_mode=True)
display_splash_screen(args)
# Create application instance directly instead of using factory function
app = create_app(args)
# Start Uvicorn in single process mode
uvicorn_config = {
"app": app, # Pass application instance directly instead of string path
"host": args.host,
"port": args.port,
"log_config": None, # Disable default config
}
if args.ssl:
uvicorn_config.update(
{
"ssl_certfile": args.ssl_certfile,
"ssl_keyfile": args.ssl_keyfile,
}
)
print(f"Starting Uvicorn server in single-process mode on {args.host}:{args.port}")
uvicorn.run(**uvicorn_config)
if __name__ == "__main__":
main()
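# Illustrative direct launch (a sketch; the flag names are assumed to mirror
# the parsed argument names, and 9621 is just an example port):
#   python -m lightrag.api.lightrag_server --host 0.0.0.0 --port 9621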