LightRAG/lightrag/api/lightrag_server.py
yangdx 3c080a9ebf Enhance webui mounting with root endpoint and directory check.
- Added FileResponse for webui root endpoint
- Enabled directory check in StaticFiles mount
- Improved webui static file handling
- Ensured webui directory existence
- Simplified webui access with root endpoint
2025-02-20 04:04:54 +08:00

911 lines
32 KiB
Python

"""
LightRAG FastAPI Server
"""
from fastapi import (
FastAPI,
HTTPException,
Depends,
)
from fastapi.responses import FileResponse
import asyncio
import threading
import os
from fastapi.staticfiles import StaticFiles
import logging
import argparse
from typing import Optional, Dict
from pathlib import Path
import configparser
from ascii_colors import ASCIIColors
import sys
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from .utils_api import get_api_key_dependency
from lightrag import LightRAG
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__
from lightrag.utils import EmbeddingFunc
from lightrag.utils import logger
from .routers.document_routes import (
DocumentManager,
create_document_routes,
run_scanning_process,
)
from .routers.query_routes import create_query_routes
from .routers.graph_routes import create_graph_routes
from .routers.ollama_api import OllamaAPI, ollama_server_infos
# Load variables from a .env file, letting its values override variables that
# are already present in the process environment. A failure is only logged so
# the server can still start from plain environment variables.
try:
    load_dotenv(override=True)
except Exception as dotenv_error:
    logger.warning(f"Failed to load .env file: {dotenv_error}")

# Optional INI configuration, read relative to the current working directory
# (ConfigParser.read skips files that cannot be opened).
config = configparser.ConfigParser()
config.read("config.ini")
# Module-wide retrieval depth; create_app() overwrites it with args.top_k.
global_top_k: int = 60
class DefaultRAGStorageConfig:
    """Fallback storage backend names used by parse_args() when neither the
    matching LIGHTRAG_*_STORAGE environment variable nor a CLI flag is given."""

    GRAPH_STORAGE = "NetworkXStorage"
    KV_STORAGE = "JsonKVStorage"
    DOC_STATUS_STORAGE = "JsonDocStatusStorage"
    VECTOR_STORAGE = "NanoVectorDBStorage"
# Shared state for the document-scanning job; the app lifespan checks and
# resets it before launching a background scan at startup.
scan_progress: Dict = dict(
    is_scanning=False,
    current_file="",
    indexed_count=0,
    total_files=0,
    progress=0,
)

# Serializes check-and-set access to scan_progress.
progress_lock = threading.Lock()
def get_default_host(binding_type: str) -> str:
    """Return the default service URL for an LLM/embedding binding.

    Azure reads AZURE_OPENAI_ENDPOINT. Every other binding reads
    LLM_BINDING_HOST and falls back to a per-binding default. Unknown
    bindings use the Ollama address.
    """
    if binding_type == "azure_openai":
        return os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1")

    builtin_defaults = {
        "lollms": "http://localhost:9600",
        "openai": "https://api.openai.com/v1",
    }
    # "ollama" and any unrecognised binding share the Ollama default address.
    fallback = builtin_defaults.get(binding_type, "http://localhost:11434")
    return os.getenv("LLM_BINDING_HOST", fallback)
def get_env_value(env_key: str, default: object, value_type: type = str) -> object:
    """
    Get value from environment variable with type conversion

    Args:
        env_key (str): Environment variable key
        default (object): Value returned when the variable is unset or its
            value cannot be converted
        value_type (type): Type (or any one-argument callable) used to convert
            the raw string. ``bool`` is special-cased: "true", "1", "yes", "t"
            and "on" (case-insensitive) mean True, anything else False.

    Returns:
        object: Converted value from environment or default
    """
    value = os.getenv(env_key)
    if value is None:
        return default

    # bool("false") is True, so booleans are parsed from a keyword list instead.
    if value_type is bool:
        return value.lower() in ("true", "1", "yes", "t", "on")

    try:
        return value_type(value)
    except (ValueError, TypeError):
        # Malformed value (or a converter rejecting it): use the default
        # instead of aborting startup.
        return default
def display_splash_screen(args: argparse.Namespace) -> None:
    """
    Display a colorful splash screen showing LightRAG server configuration

    Prints the server, directory, LLM, embedding, RAG, storage and system
    settings held in ``args``, then access URLs, a quick-start guide and an
    optional security notice, and finally flushes stdout.

    Args:
        args: Parsed command line arguments
    """
    # Banner
    ASCIIColors.cyan(f"""
╔══════════════════════════════════════════════════════════════╗
║ 🚀 LightRAG Server v{__api_version__}
║ Fast, Lightweight RAG Server Implementation ║
╚══════════════════════════════════════════════════════════════╝
""")

    # Server Configuration
    ASCIIColors.magenta("\n📡 Server Configuration:")
    ASCIIColors.white(" ├─ Host: ", end="")
    ASCIIColors.yellow(f"{args.host}")
    ASCIIColors.white(" ├─ Port: ", end="")
    ASCIIColors.yellow(f"{args.port}")
    ASCIIColors.white(" ├─ CORS Origins: ", end="")
    ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}")
    ASCIIColors.white(" ├─ SSL Enabled: ", end="")
    ASCIIColors.yellow(f"{args.ssl}")
    # "API Key" is the last branch only when no SSL cert/key lines follow it.
    ASCIIColors.white(" ├─ API Key: " if args.ssl else " └─ API Key: ", end="")
    ASCIIColors.yellow("Set" if args.key else "Not Set")
    if args.ssl:
        ASCIIColors.white(" ├─ SSL Cert: ", end="")
        ASCIIColors.yellow(f"{args.ssl_certfile}")
        ASCIIColors.white(" └─ SSL Key: ", end="")
        ASCIIColors.yellow(f"{args.ssl_keyfile}")

    # Directory Configuration
    ASCIIColors.magenta("\n📂 Directory Configuration:")
    ASCIIColors.white(" ├─ Working Directory: ", end="")
    ASCIIColors.yellow(f"{args.working_dir}")
    ASCIIColors.white(" └─ Input Directory: ", end="")
    ASCIIColors.yellow(f"{args.input_dir}")

    # LLM Configuration
    ASCIIColors.magenta("\n🤖 LLM Configuration:")
    ASCIIColors.white(" ├─ Binding: ", end="")
    ASCIIColors.yellow(f"{args.llm_binding}")
    ASCIIColors.white(" ├─ Host: ", end="")
    ASCIIColors.yellow(f"{args.llm_binding_host}")
    ASCIIColors.white(" └─ Model: ", end="")
    ASCIIColors.yellow(f"{args.llm_model}")

    # Embedding Configuration
    ASCIIColors.magenta("\n📊 Embedding Configuration:")
    ASCIIColors.white(" ├─ Binding: ", end="")
    ASCIIColors.yellow(f"{args.embedding_binding}")
    ASCIIColors.white(" ├─ Host: ", end="")
    ASCIIColors.yellow(f"{args.embedding_binding_host}")
    ASCIIColors.white(" ├─ Model: ", end="")
    ASCIIColors.yellow(f"{args.embedding_model}")
    ASCIIColors.white(" └─ Dimensions: ", end="")
    ASCIIColors.yellow(f"{args.embedding_dim}")

    # RAG Configuration
    ASCIIColors.magenta("\n⚙️ RAG Configuration:")
    ASCIIColors.white(" ├─ Max Async Operations: ", end="")
    ASCIIColors.yellow(f"{args.max_async}")
    ASCIIColors.white(" ├─ Max Tokens: ", end="")
    ASCIIColors.yellow(f"{args.max_tokens}")
    ASCIIColors.white(" ├─ Max Embed Tokens: ", end="")
    ASCIIColors.yellow(f"{args.max_embed_tokens}")
    ASCIIColors.white(" ├─ Chunk Size: ", end="")
    ASCIIColors.yellow(f"{args.chunk_size}")
    ASCIIColors.white(" ├─ Chunk Overlap Size: ", end="")
    ASCIIColors.yellow(f"{args.chunk_overlap_size}")
    ASCIIColors.white(" ├─ History Turns: ", end="")
    ASCIIColors.yellow(f"{args.history_turns}")
    ASCIIColors.white(" ├─ Cosine Threshold: ", end="")
    ASCIIColors.yellow(f"{args.cosine_threshold}")
    ASCIIColors.white(" └─ Top-K: ", end="")
    ASCIIColors.yellow(f"{args.top_k}")

    # Storage Configuration
    ASCIIColors.magenta("\n💾 Storage Configuration:")
    ASCIIColors.white(" ├─ KV Storage: ", end="")
    ASCIIColors.yellow(f"{args.kv_storage}")
    ASCIIColors.white(" ├─ Vector Storage: ", end="")
    ASCIIColors.yellow(f"{args.vector_storage}")
    ASCIIColors.white(" ├─ Graph Storage: ", end="")
    ASCIIColors.yellow(f"{args.graph_storage}")
    ASCIIColors.white(" └─ Document Status Storage: ", end="")
    ASCIIColors.yellow(f"{args.doc_status_storage}")

    # System Configuration
    ASCIIColors.magenta("\n🛠️ System Configuration:")
    ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="")
    ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
    ASCIIColors.white(" ├─ Log Level: ", end="")
    ASCIIColors.yellow(f"{args.log_level}")
    ASCIIColors.white(" ├─ Verbose Debug: ", end="")
    ASCIIColors.yellow(f"{args.verbose}")
    ASCIIColors.white(" └─ Timeout: ", end="")
    ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")

    # Server Status
    ASCIIColors.green("\n✨ Server starting up...\n")

    # Server Access Information
    protocol = "https" if args.ssl else "http"
    if args.host == "0.0.0.0":
        ASCIIColors.magenta("\n🌐 Server Access Information:")
        ASCIIColors.white(" ├─ Local Access: ", end="")
        ASCIIColors.yellow(f"{protocol}://localhost:{args.port}")
        ASCIIColors.white(" ├─ Remote Access: ", end="")
        ASCIIColors.yellow(f"{protocol}://<your-ip-address>:{args.port}")
        ASCIIColors.white(" ├─ API Documentation (local): ", end="")
        ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/docs")
        ASCIIColors.white(" ├─ Alternative Documentation (local): ", end="")
        ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/redoc")
        ASCIIColors.white(" └─ WebUI (local): ", end="")
        ASCIIColors.yellow(f"{protocol}://localhost:{args.port}/webui")

        ASCIIColors.yellow("\n📝 Note:")
        ASCIIColors.white(""" Since the server is running on 0.0.0.0:
- Use 'localhost' or '127.0.0.1' for local access
- Use your machine's IP address for remote access
- To find your IP address:
• Windows: Run 'ipconfig' in terminal
• Linux/Mac: Run 'ifconfig' or 'ip addr' in terminal
""")
    else:
        base_url = f"{protocol}://{args.host}:{args.port}"
        ASCIIColors.magenta("\n🌐 Server Access Information:")
        ASCIIColors.white(" ├─ Base URL: ", end="")
        ASCIIColors.yellow(f"{base_url}")
        ASCIIColors.white(" ├─ API Documentation: ", end="")
        ASCIIColors.yellow(f"{base_url}/docs")
        ASCIIColors.white(" ├─ Alternative Documentation: ", end="")
        ASCIIColors.yellow(f"{base_url}/redoc")
        # The web UI is mounted under /webui regardless of the bind address.
        ASCIIColors.white(" └─ WebUI: ", end="")
        ASCIIColors.yellow(f"{base_url}/webui")

    # Usage Examples
    ASCIIColors.magenta("\n📚 Quick Start Guide:")
    ASCIIColors.cyan("""
1. Access the Swagger UI:
Open your browser and navigate to the API documentation URL above
2. API Authentication:""")
    if args.key:
        ASCIIColors.cyan(""" Add the following header to your requests:
X-API-Key: <your-api-key>
""")
    else:
        ASCIIColors.cyan(" No authentication required\n")

    ASCIIColors.cyan(""" 3. Basic Operations:
- POST /upload_document: Upload new documents to RAG
- POST /query: Query your document collection
- GET /collections: List available collections
4. Monitor the server:
- Check server logs for detailed operation information
- Use healthcheck endpoint: GET /health
""")

    # Security Notice
    if args.key:
        ASCIIColors.yellow("\n⚠️ Security Notice:")
        ASCIIColors.white(""" API Key authentication is enabled.
Make sure to include the X-API-Key header in all your requests.
""")

    ASCIIColors.green("Server is ready to accept connections! 🚀\n")

    # Ensure splash output flush to system log
    sys.stdout.flush()
def parse_args() -> argparse.Namespace:
    """
    Parse command line arguments with environment variable fallback

    Most option defaults are read from environment variables (via
    get_env_value), so an explicit CLI flag overrides the environment, which
    overrides the built-in default. After parsing, --working-dir and
    --input-dir are made absolute, and --simulated-model-name is written back
    to ollama_server_infos.LIGHTRAG_MODEL.

    Returns:
        argparse.Namespace: Parsed arguments
    """
    parser = argparse.ArgumentParser(
        description="LightRAG FastAPI Server with separate working and input directories"
    )

    # Storage backends
    parser.add_argument(
        "--kv-storage",
        default=get_env_value(
            "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE
        ),
        help=f"KV storage implementation (default: {DefaultRAGStorageConfig.KV_STORAGE})",
    )
    parser.add_argument(
        "--doc-status-storage",
        default=get_env_value(
            "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE
        ),
        help=f"Document status storage implementation (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})",
    )
    parser.add_argument(
        "--graph-storage",
        default=get_env_value(
            "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE
        ),
        help=f"Graph storage implementation (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})",
    )
    parser.add_argument(
        "--vector-storage",
        default=get_env_value(
            "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
        ),
        help=f"Vector storage implementation (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})",
    )

    # Bindings configuration (the supported lists mirror create_app's validation)
    parser.add_argument(
        "--llm-binding",
        default=get_env_value("LLM_BINDING", "ollama"),
        help="LLM binding to be used. Supported: lollms, ollama, openai, openai-ollama, azure_openai (default: from env or ollama)",
    )
    parser.add_argument(
        "--embedding-binding",
        default=get_env_value("EMBEDDING_BINDING", "ollama"),
        help="Embedding binding to be used. Supported: lollms, ollama, openai, azure_openai (default: from env or ollama)",
    )

    # Server configuration
    parser.add_argument(
        "--host",
        default=get_env_value("HOST", "0.0.0.0"),
        help="Server host (default: from env or 0.0.0.0)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=get_env_value("PORT", 9621, int),
        help="Server port (default: from env or 9621)",
    )

    # Directory configuration
    parser.add_argument(
        "--working-dir",
        default=get_env_value("WORKING_DIR", "./rag_storage"),
        help="Working directory for RAG storage (default: from env or ./rag_storage)",
    )
    parser.add_argument(
        "--input-dir",
        default=get_env_value("INPUT_DIR", "./inputs"),
        help="Directory containing input documents (default: from env or ./inputs)",
    )

    # LLM Model configuration
    parser.add_argument(
        "--llm-binding-host",
        default=get_env_value("LLM_BINDING_HOST", None),
        help="LLM server host URL. If not provided, defaults based on llm-binding:\n"
        + "- ollama: http://localhost:11434\n"
        + "- lollms: http://localhost:9600\n"
        + "- openai: https://api.openai.com/v1\n"
        + "- azure_openai: from AZURE_OPENAI_ENDPOINT",
    )
    default_llm_api_key = get_env_value("LLM_BINDING_API_KEY", None)
    parser.add_argument(
        "--llm-binding-api-key",
        default=default_llm_api_key,
        help="llm server API key (default: from env or None)",
    )
    parser.add_argument(
        "--llm-model",
        default=get_env_value("LLM_MODEL", "mistral-nemo:latest"),
        help="LLM model name (default: from env or mistral-nemo:latest)",
    )

    # Embedding model configuration
    parser.add_argument(
        "--embedding-binding-host",
        default=get_env_value("EMBEDDING_BINDING_HOST", None),
        help="Embedding server host URL. If not provided, defaults based on embedding-binding:\n"
        + "- ollama: http://localhost:11434\n"
        + "- lollms: http://localhost:9600\n"
        + "- openai: https://api.openai.com/v1",
    )
    default_embedding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
    parser.add_argument(
        "--embedding-binding-api-key",
        default=default_embedding_api_key,
        help="embedding server API key (default: from env or empty string)",
    )
    parser.add_argument(
        "--embedding-model",
        default=get_env_value("EMBEDDING_MODEL", "bge-m3:latest"),
        help="Embedding model name (default: from env or bge-m3:latest)",
    )

    # Chunking: parsed as int both from the CLI and from the environment
    # (previously env values arrived as str).
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=get_env_value("CHUNK_SIZE", 1200, int),
        help="Chunk size in tokens (default: from env or 1200)",
    )
    parser.add_argument(
        "--chunk_overlap_size",
        type=int,
        default=get_env_value("CHUNK_OVERLAP_SIZE", 100, int),
        help="Chunk overlap size in tokens (default: from env or 100)",
    )

    def timeout_type(value):
        """Convert a timeout value; "None" (or None) means no timeout."""
        if value is None or value == "None":
            return None
        return int(value)

    parser.add_argument(
        "--timeout",
        default=get_env_value("TIMEOUT", None, timeout_type),
        type=timeout_type,
        help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
    )

    # RAG configuration
    parser.add_argument(
        "--max-async",
        type=int,
        default=get_env_value("MAX_ASYNC", 4, int),
        help="Maximum async operations (default: from env or 4)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=get_env_value("MAX_TOKENS", 32768, int),
        help="Maximum token size (default: from env or 32768)",
    )
    parser.add_argument(
        "--embedding-dim",
        type=int,
        default=get_env_value("EMBEDDING_DIM", 1024, int),
        help="Embedding dimensions (default: from env or 1024)",
    )
    parser.add_argument(
        "--max-embed-tokens",
        type=int,
        default=get_env_value("MAX_EMBED_TOKENS", 8192, int),
        help="Maximum embedding token size (default: from env or 8192)",
    )

    # Logging configuration. type=str.upper normalizes both CLI input and the
    # env default (argparse applies `type` to string defaults but never checks
    # them against `choices`), so LOG_LEVEL=info no longer yields a bogus level.
    parser.add_argument(
        "--log-level",
        type=str.upper,
        default=get_env_value("LOG_LEVEL", "INFO"),
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Logging level (default: from env or INFO)",
    )

    parser.add_argument(
        "--key",
        type=str,
        default=get_env_value("LIGHTRAG_API_KEY", None),
        help="API key for authentication. This protects lightrag server against unauthorized access",
    )

    # Optional https parameters
    parser.add_argument(
        "--ssl",
        action="store_true",
        default=get_env_value("SSL", False, bool),
        help="Enable HTTPS (default: from env or False)",
    )
    parser.add_argument(
        "--ssl-certfile",
        default=get_env_value("SSL_CERTFILE", None),
        help="Path to SSL certificate file (required if --ssl is enabled)",
    )
    parser.add_argument(
        "--ssl-keyfile",
        default=get_env_value("SSL_KEYFILE", None),
        help="Path to SSL private key file (required if --ssl is enabled)",
    )
    parser.add_argument(
        "--auto-scan-at-startup",
        action="store_true",
        default=False,
        help="Enable automatic scanning when the program starts",
    )

    parser.add_argument(
        "--history-turns",
        type=int,
        default=get_env_value("HISTORY_TURNS", 3, int),
        help="Number of conversation history turns to include (default: from env or 3)",
    )

    # Search parameters
    parser.add_argument(
        "--top-k",
        type=int,
        default=get_env_value("TOP_K", 60, int),
        help="Number of most similar results to return (default: from env or 60)",
    )
    parser.add_argument(
        "--cosine-threshold",
        type=float,
        default=get_env_value("COSINE_THRESHOLD", 0.2, float),
        help="Cosine similarity threshold (default: from env or 0.2)",
    )

    # Ollama model name
    parser.add_argument(
        "--simulated-model-name",
        type=str,
        default=get_env_value(
            "SIMULATED_MODEL_NAME", ollama_server_infos.LIGHTRAG_MODEL
        ),
        help=f"Model name presented by the Ollama-emulating API (default: from env or {ollama_server_infos.LIGHTRAG_MODEL})",
    )

    # Namespace
    parser.add_argument(
        "--namespace-prefix",
        type=str,
        default=get_env_value("NAMESPACE_PREFIX", ""),
        help="Prefix of the namespace",
    )

    def str_to_bool(value):
        """Parse boolean CLI text with the same keywords get_env_value accepts."""
        if isinstance(value, bool):
            return value
        return value.strip().lower() in ("true", "1", "yes", "t", "on")

    # type=bool would turn "--verbose False" into True (any non-empty string is
    # truthy). str_to_bool parses the text; a bare "--verbose" now means True.
    parser.add_argument(
        "--verbose",
        type=str_to_bool,
        nargs="?",
        const=True,
        default=get_env_value("VERBOSE", False, bool),
        help="Verbose debug output(default: from env or false)",
    )

    args = parser.parse_args()

    # convert relative path to absolute path
    args.working_dir = os.path.abspath(args.working_dir)
    args.input_dir = os.path.abspath(args.input_dir)

    ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name

    return args
def create_app(args):
    """
    Build the LightRAG FastAPI application from parsed arguments.

    Validates the LLM/embedding bindings and the SSL settings, fills in default
    binding hosts, creates the LightRAG instance, and registers the document,
    query, graph, Ollama-compatible (/api) and /health routes. It also serves
    the web UI under /webui. Storages are initialized on startup and finalized
    on shutdown in the app lifespan.

    Args:
        args: Namespace produced by parse_args(). The binding host fields may
            be filled in (mutated) here.

    Returns:
        FastAPI: the configured application.

    Raises:
        Exception: unsupported binding, or SSL enabled with missing or
            nonexistent certificate/key files.
    """
    # Set global top_k
    global global_top_k
    global_top_k = args.top_k  # save top_k from args

    # Initialize verbose debug setting
    from lightrag.utils import set_verbose_debug

    set_verbose_debug(args.verbose)

    # Verify that bindings are correctly setup
    if args.llm_binding not in [
        "lollms",
        "ollama",
        "openai",
        "openai-ollama",
        "azure_openai",
    ]:
        raise Exception("llm binding not supported")

    if args.embedding_binding not in ["lollms", "ollama", "openai", "azure_openai"]:
        raise Exception("embedding binding not supported")

    # Set default hosts if not provided
    if args.llm_binding_host is None:
        args.llm_binding_host = get_default_host(args.llm_binding)

    if args.embedding_binding_host is None:
        args.embedding_binding_host = get_default_host(args.embedding_binding)

    # Add SSL validation
    if args.ssl:
        if not args.ssl_certfile or not args.ssl_keyfile:
            raise Exception(
                "SSL certificate and key files must be provided when SSL is enabled"
            )
        if not os.path.exists(args.ssl_certfile):
            raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
        if not os.path.exists(args.ssl_keyfile):
            raise Exception(f"SSL key file not found: {args.ssl_keyfile}")

    # Setup logging
    logging.basicConfig(
        format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
    )

    # An explicit --key takes precedence, with LIGHTRAG_API_KEY as the fallback.
    # The env var used to win even over an explicit CLI flag.
    api_key = args.key or os.getenv("LIGHTRAG_API_KEY")

    # Initialize document manager
    doc_manager = DocumentManager(args.input_dir)

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """Lifespan context manager for startup and shutdown events"""
        # Store background tasks (strong references keep them from being GC'd)
        app.state.background_tasks = set()

        try:
            # Initialize database connections
            await rag.initialize_storages()

            # Auto scan documents if enabled
            if args.auto_scan_at_startup:
                # Check-and-set under the lock so only one scan is started
                with progress_lock:
                    if not scan_progress["is_scanning"]:
                        scan_progress["is_scanning"] = True
                        scan_progress["indexed_count"] = 0
                        scan_progress["progress"] = 0
                        # Create background task
                        task = asyncio.create_task(
                            run_scanning_process(rag, doc_manager)
                        )
                        app.state.background_tasks.add(task)
                        task.add_done_callback(app.state.background_tasks.discard)
                        ASCIIColors.info(
                            f"Started background scanning of documents from {args.input_dir}"
                        )
                    else:
                        ASCIIColors.info(
                            "Skip document scanning(another scanning is active)"
                        )

            yield

        finally:
            # Clean up database connections
            await rag.finalize_storages()

    # Initialize FastAPI
    app = FastAPI(
        title="LightRAG API",
        # Parenthesized so only the suffix is conditional. Before, the
        # conditional expression dropped the whole description when no API key
        # was configured.
        description=(
            "API for querying text using LightRAG with separate storage and input directories"
            + (" (With authentication)" if api_key else "")
        ),
        version=__api_version__,
        openapi_tags=[{"name": "api"}],
        lifespan=lifespan,
    )

    def get_cors_origins():
        """Get allowed origins from environment variable
        Returns a list of allowed origins, defaults to ["*"] if not set
        """
        origins_str = os.getenv("CORS_ORIGINS", "*")
        if origins_str == "*":
            return ["*"]
        return [origin.strip() for origin in origins_str.split(",")]

    # Add CORS middleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=get_cors_origins(),
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Create the optional API key dependency
    optional_api_key = get_api_key_dependency(api_key)

    # Create working directory if it doesn't exist
    Path(args.working_dir).mkdir(parents=True, exist_ok=True)

    # Import only the binding modules that are actually needed.
    if args.llm_binding == "lollms" or args.embedding_binding == "lollms":
        from lightrag.llm.lollms import lollms_model_complete, lollms_embed
    if args.llm_binding == "ollama" or args.embedding_binding == "ollama":
        from lightrag.llm.ollama import ollama_model_complete, ollama_embed
    if args.llm_binding == "openai" or args.embedding_binding == "openai":
        from lightrag.llm.openai import openai_complete_if_cache, openai_embed
    if args.llm_binding == "azure_openai" or args.embedding_binding == "azure_openai":
        from lightrag.llm.azure_openai import (
            azure_openai_complete_if_cache,
            azure_openai_embed,
        )
    # openai-ollama drives the LLM through the OpenAI-compatible client. This
    # previously compared llm_binding_host (a URL) with "openai-ollama", so the
    # import only happened by accident through other bindings. Ollama embeddings
    # are already imported above when embedding_binding == "ollama".
    if args.llm_binding == "openai-ollama":
        from lightrag.llm.openai import openai_complete_if_cache

    async def openai_alike_model_complete(
        prompt,
        system_prompt=None,
        history_messages=None,
        keyword_extraction=False,
        **kwargs,
    ) -> str:
        """Complete `prompt` via the OpenAI-compatible endpoint at --llm-binding-host."""
        # NOTE(review): `keyword_extraction` is bound to the named parameter, so
        # it never appears in **kwargs. This pop therefore always yields None and
        # `response_format` is never set. Left unchanged because enabling it would
        # switch every OpenAI-alike backend to structured outputs; confirm intent.
        keyword_extraction = kwargs.pop("keyword_extraction", None)
        if keyword_extraction:
            kwargs["response_format"] = GPTKeywordExtractionFormat
        if history_messages is None:
            history_messages = []
        return await openai_complete_if_cache(
            args.llm_model,
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            base_url=args.llm_binding_host,
            api_key=args.llm_binding_api_key,
            **kwargs,
        )

    async def azure_openai_model_complete(
        prompt,
        system_prompt=None,
        history_messages=None,
        keyword_extraction=False,
        **kwargs,
    ) -> str:
        """Complete `prompt` via Azure OpenAI (key/version from AZURE_OPENAI_* env vars)."""
        # NOTE(review): same dead pop as in openai_alike_model_complete: the
        # named parameter captures `keyword_extraction`, so this is always None.
        keyword_extraction = kwargs.pop("keyword_extraction", None)
        if keyword_extraction:
            kwargs["response_format"] = GPTKeywordExtractionFormat
        if history_messages is None:
            history_messages = []
        return await azure_openai_complete_if_cache(
            args.llm_model,
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            base_url=args.llm_binding_host,
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"),
            **kwargs,
        )

    # Dispatch on embedding_binding lazily inside the lambda, so only the
    # imported embed function for the chosen binding is ever referenced.
    embedding_func = EmbeddingFunc(
        embedding_dim=args.embedding_dim,
        max_token_size=args.max_embed_tokens,
        func=lambda texts: lollms_embed(
            texts,
            embed_model=args.embedding_model,
            host=args.embedding_binding_host,
            api_key=args.embedding_binding_api_key,
        )
        if args.embedding_binding == "lollms"
        else ollama_embed(
            texts,
            embed_model=args.embedding_model,
            host=args.embedding_binding_host,
            api_key=args.embedding_binding_api_key,
        )
        if args.embedding_binding == "ollama"
        else azure_openai_embed(
            texts,
            model=args.embedding_model,  # no host is used for openai,
            api_key=args.embedding_binding_api_key,
        )
        if args.embedding_binding == "azure_openai"
        else openai_embed(
            texts,
            model=args.embedding_model,
            base_url=args.embedding_binding_host,
            api_key=args.embedding_binding_api_key,
        ),
    )

    # Initialize RAG
    if args.llm_binding in ["lollms", "ollama", "openai-ollama"]:
        rag = LightRAG(
            working_dir=args.working_dir,
            llm_model_func=lollms_model_complete
            if args.llm_binding == "lollms"
            else ollama_model_complete
            if args.llm_binding == "ollama"
            else openai_alike_model_complete,
            llm_model_name=args.llm_model,
            llm_model_max_async=args.max_async,
            llm_model_max_token_size=args.max_tokens,
            chunk_token_size=int(args.chunk_size),
            chunk_overlap_token_size=int(args.chunk_overlap_size),
            llm_model_kwargs={
                "host": args.llm_binding_host,
                "timeout": args.timeout,
                "options": {"num_ctx": args.max_tokens},
                "api_key": args.llm_binding_api_key,
            }
            if args.llm_binding == "lollms" or args.llm_binding == "ollama"
            else {},
            embedding_func=embedding_func,
            kv_storage=args.kv_storage,
            graph_storage=args.graph_storage,
            vector_storage=args.vector_storage,
            doc_status_storage=args.doc_status_storage,
            vector_db_storage_cls_kwargs={
                "cosine_better_than_threshold": args.cosine_threshold
            },
            enable_llm_cache_for_entity_extract=False,  # set to True for debuging to reduce llm fee
            embedding_cache_config={
                "enabled": True,
                "similarity_threshold": 0.95,
                "use_llm_check": False,
            },
            log_level=args.log_level,
            namespace_prefix=args.namespace_prefix,
            auto_manage_storages_states=False,
        )
    else:
        rag = LightRAG(
            working_dir=args.working_dir,
            llm_model_func=azure_openai_model_complete
            if args.llm_binding == "azure_openai"
            else openai_alike_model_complete,
            chunk_token_size=int(args.chunk_size),
            chunk_overlap_token_size=int(args.chunk_overlap_size),
            llm_model_kwargs={
                "timeout": args.timeout,
            },
            llm_model_name=args.llm_model,
            llm_model_max_async=args.max_async,
            llm_model_max_token_size=args.max_tokens,
            embedding_func=embedding_func,
            kv_storage=args.kv_storage,
            graph_storage=args.graph_storage,
            vector_storage=args.vector_storage,
            doc_status_storage=args.doc_status_storage,
            vector_db_storage_cls_kwargs={
                "cosine_better_than_threshold": args.cosine_threshold
            },
            enable_llm_cache_for_entity_extract=False,  # set to True for debuging to reduce llm fee
            embedding_cache_config={
                "enabled": True,
                "similarity_threshold": 0.95,
                "use_llm_check": False,
            },
            log_level=args.log_level,
            namespace_prefix=args.namespace_prefix,
            auto_manage_storages_states=False,
        )

    # Add routes
    app.include_router(create_document_routes(rag, doc_manager, api_key))
    app.include_router(create_query_routes(rag, api_key, args.top_k))
    app.include_router(create_graph_routes(rag, api_key))

    # Add Ollama API routes
    ollama_api = OllamaAPI(rag, top_k=args.top_k)
    app.include_router(ollama_api.router, prefix="/api")

    @app.get("/health", dependencies=[Depends(optional_api_key)])
    async def get_status():
        """Get current system status"""
        return {
            "status": "healthy",
            "working_directory": str(args.working_dir),
            "input_directory": str(args.input_dir),
            "configuration": {
                # LLM configuration binding/host address (if applicable)/model (if applicable)
                "llm_binding": args.llm_binding,
                "llm_binding_host": args.llm_binding_host,
                "llm_model": args.llm_model,
                # embedding model configuration binding/host address (if applicable)/model (if applicable)
                "embedding_binding": args.embedding_binding,
                "embedding_binding_host": args.embedding_binding_host,
                "embedding_model": args.embedding_model,
                "max_tokens": args.max_tokens,
                "kv_storage": args.kv_storage,
                "doc_status_storage": args.doc_status_storage,
                "graph_storage": args.graph_storage,
                "vector_storage": args.vector_storage,
            },
        }

    # Webui mount webui/index.html
    static_dir = Path(__file__).parent / "webui"
    static_dir.mkdir(exist_ok=True)

    # Register the explicit root route BEFORE the mount. Starlette matches routes
    # in registration order, and the /webui mount also matches "/webui/", so a
    # route added afterwards is never reached.
    @app.get("/webui/")
    async def webui_root():
        return FileResponse(static_dir / "index.html")

    app.mount(
        "/webui",
        StaticFiles(directory=static_dir, html=True, check_dir=True),
        name="webui",
    )

    return app
def main():
    """Parse configuration, build the app, show the splash screen and run uvicorn."""
    args = parse_args()
    import uvicorn

    app = create_app(args)
    display_splash_screen(args)

    # TLS settings are only forwarded to uvicorn when --ssl is enabled.
    tls_options = (
        {"ssl_certfile": args.ssl_certfile, "ssl_keyfile": args.ssl_keyfile}
        if args.ssl
        else {}
    )
    uvicorn.run(app=app, host=args.host, port=args.port, **tls_options)
if __name__ == "__main__":
main()