mirror of
https://github.com/HKUDS/LightRAG.git
synced 2025-07-26 10:20:31 +00:00

- Remove MAX_TOKEN_SUMMARY parameter and related configurations - Eliminate forced token-based truncation in entity/relationship descriptions - Switch to fragment-count based summarization logic using FORCE_LLM_SUMMARY_ON_MERGE - Update FORCE_LLM_SUMMARY_ON_MERGE default from 6 to 4 for better summarization - Clean up documentation, environment examples, and API display code - Preserve backward compatibility by graceful parameter removal This change resolves issues where LLMs were forcibly truncating entity relationship descriptions mid-sentence, leading to incomplete and potentially inaccurate knowledge graph content. The new approach allows LLMs to generate complete descriptions while still providing summarization when multiple fragments need to be merged. Breaking Change: None - parameter removal is backward compatible Fixes: Entity relationship description truncation issues
348 lines
14 KiB
Python
348 lines
14 KiB
Python
"""
|
|
Utility functions for the LightRAG API.
|
|
"""
|
|
|
|
import os
|
|
import argparse
|
|
from typing import Optional, List, Tuple
|
|
import sys
|
|
from ascii_colors import ASCIIColors
|
|
from lightrag.api import __api_version__ as api_version
|
|
from lightrag import __version__ as core_version
|
|
from lightrag.constants import (
|
|
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
|
|
)
|
|
from fastapi import HTTPException, Security, Request, status
|
|
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
|
from starlette.status import HTTP_403_FORBIDDEN
|
|
from .auth import auth_handler
|
|
from .config import ollama_server_infos, global_args, get_env_value
|
|
|
|
|
|
def check_env_file():
    """
    Check if .env file exists and handle user confirmation if needed.

    When no ``.env`` file is present in the current working directory a
    warning is printed. If stdin is an interactive terminal the user is then
    asked whether to continue; otherwise startup proceeds without prompting.

    Returns:
        bool: True if startup should continue, False if it should be aborted.
    """
    if os.path.exists(".env"):
        return True

    warning_msg = "Warning: Startup directory must contain .env file for multi-instance support."
    ASCIIColors.yellow(warning_msg)

    # Only prompt when running in an interactive terminal
    if not sys.stdin.isatty():
        return True

    try:
        response = input("Do you want to continue? (yes/no): ")
    except EOFError:
        # Prompt was closed without an answer (e.g. Ctrl-D): treat as "no"
        # instead of crashing with a traceback.
        response = ""

    # Tolerate surrounding whitespace, any letter case, and the common "y" shorthand
    if response.strip().lower() not in ("yes", "y"):
        ASCIIColors.red("Server startup cancelled")
        return False
    return True
|
|
|
|
|
|
# Whitelist paths are read from global_args a single time, at import
whitelist_paths = global_args.whitelist_paths.split(",")

# Pre-computed matching rules: each entry is (path_or_prefix, is_prefix_match).
# An entry written as "<prefix>/*" is stored as a prefix rule with the "/*"
# removed; any other non-empty entry is stored as an exact-match rule.
whitelist_patterns: List[Tuple[str, bool]] = []
for raw_entry in whitelist_paths:
    entry = raw_entry.strip()
    if not entry:
        # Skip blank items such as those produced by ",," or a trailing comma
        continue
    is_prefix_match = entry.endswith("/*")
    rule_path = entry[:-2] if is_prefix_match else entry
    whitelist_patterns.append((rule_path, is_prefix_match))

# Global authentication configuration: true when auth_handler has accounts
auth_configured = bool(auth_handler.accounts)
|
|
|
|
|
|
def get_combined_auth_dependency(api_key: Optional[str] = None):
    """
    Create a combined authentication dependency that implements authentication logic
    based on API key, OAuth2 token, and whitelist paths.

    The returned dependency checks, in order: whitelist paths, a bearer token
    (if one was sent), whether any protection is configured at all, and finally
    the X-API-Key header. It returns None when access is allowed and raises
    HTTPException (401 or 403) otherwise.

    Args:
        api_key (Optional[str]): API key for validation

    Returns:
        Callable: A dependency function that implements the authentication logic
    """
    # Function-scope import so this block does not depend on the file's import list
    import hmac

    # whitelist_patterns and auth_configured are already initialized at module level.
    # Only api_key_configured is calculated here as it depends on the function parameter.
    api_key_configured = bool(api_key)

    # Create security dependencies with proper descriptions for Swagger UI
    oauth2_scheme = OAuth2PasswordBearer(
        tokenUrl="login", auto_error=False, description="OAuth2 Password Authentication"
    )

    # If API key is configured, create an API key header security
    api_key_header = None
    if api_key_configured:
        api_key_header = APIKeyHeader(
            name="X-API-Key", auto_error=False, description="API Key Authentication"
        )

    async def combined_dependency(
        request: Request,
        token: str = Security(oauth2_scheme),
        api_key_header_value: Optional[str] = None
        if api_key_header is None
        else Security(api_key_header),
    ):
        # 1. Check if path is in whitelist
        path = request.url.path
        for pattern, is_prefix in whitelist_patterns:
            if is_prefix:
                # Match the prefix itself or anything below it, on a path-segment
                # boundary. A bare startswith() would let "/api/*" (stored as
                # "/api") also whitelist unrelated paths such as "/apikeys".
                if path == pattern or path.startswith(pattern + "/"):
                    return  # Whitelist path, allow access
            elif path == pattern:
                return  # Whitelist path, allow access

        # 2. Validate token first if provided in the request (Ensure 401 error if token is invalid)
        if token:
            try:
                token_info = auth_handler.validate_token(token)
                # Accept guest token if no auth is configured
                if not auth_configured and token_info.get("role") == "guest":
                    return
                # Accept non-guest token if auth is configured
                if auth_configured and token_info.get("role") != "guest":
                    return

                # Token role does not fit the current configuration: return 401 immediately
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Invalid token. Please login again.",
                )
            except HTTPException as e:
                # If already a 401 error, re-raise it
                if e.status_code == status.HTTP_401_UNAUTHORIZED:
                    raise
                # For HTTPExceptions with any other status, continue processing

        # 3. Accept all requests if no API protection is needed
        if not auth_configured and not api_key_configured:
            return

        # 4. Validate API key if provided and API-Key authentication is configured.
        #    compare_digest avoids leaking how many leading characters matched
        #    through response timing; bytes are used because the str form of
        #    compare_digest rejects non-ASCII input.
        if (
            api_key_configured
            and api_key_header_value
            and hmac.compare_digest(
                api_key_header_value.encode("utf-8"), api_key.encode("utf-8")
            )
        ):
            return  # API key validation successful

        ### Authentication failed ####

        # If password authentication is configured but no token was provided, ensure 401 error
        if auth_configured and not token:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="No credentials provided. Please login.",
            )

        # If an API key was provided but validation failed
        if api_key_header_value:
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN,
                detail="Invalid API Key",
            )

        # If an API key is configured but was not provided
        if api_key_configured and not api_key_header_value:
            raise HTTPException(
                status_code=HTTP_403_FORBIDDEN,
                detail="API Key required",
            )

        # Otherwise: refuse access and return 403 error
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="API Key required or login authentication required.",
        )

    return combined_dependency
|
|
|
|
|
|
def display_splash_screen(args: argparse.Namespace) -> None:
    """
    Display a colorful splash screen showing LightRAG server configuration

    Prints the banner, the server / directory / LLM / embedding / RAG / storage
    settings read from ``args``, the URLs the server can be reached at, and
    security notices when an API key or auth accounts are configured. Output
    is flushed at the end.

    Args:
        args: Parsed command line arguments
    """

    def show(label: str, value: object) -> None:
        # One configuration row: white label, then the value in yellow on the same line
        ASCIIColors.white(label, end="")
        ASCIIColors.yellow(f"{value}")

    # Banner
    top_border = "╔══════════════════════════════════════════════════════════════╗"
    bottom_border = "╚══════════════════════════════════════════════════════════════╝"
    width = len(top_border) - 4  # width inside the borders

    line1_text = f"LightRAG Server v{core_version}/{api_version}"
    line2_text = "Fast, Lightweight RAG Server Implementation"

    line1 = f"║ {line1_text.center(width)} ║"
    line2 = f"║ {line2_text.center(width)} ║"

    banner = f"""
{top_border}
{line1}
{line2}
{bottom_border}
"""
    ASCIIColors.cyan(banner)

    # Server Configuration
    ASCIIColors.magenta("\n📡 Server Configuration:")
    show(" ├─ Host: ", args.host)
    show(" ├─ Port: ", args.port)
    show(" ├─ Workers: ", args.workers)
    show(" ├─ CORS Origins: ", args.cors_origins)
    show(" ├─ SSL Enabled: ", args.ssl)
    if args.ssl:
        show(" ├─ SSL Cert: ", args.ssl_certfile)
        show(" ├─ SSL Key: ", args.ssl_keyfile)
    show(" ├─ Ollama Emulating Model: ", ollama_server_infos.LIGHTRAG_MODEL)
    show(" ├─ Log Level: ", args.log_level)
    show(" ├─ Verbose Debug: ", args.verbose)
    show(" ├─ History Turns: ", args.history_turns)
    # Only report whether a key exists; never print the key itself
    show(" ├─ API Key: ", "Set" if args.key else "Not Set")
    show(" └─ JWT Auth: ", "Enabled" if args.auth_accounts else "Disabled")

    # Directory Configuration
    ASCIIColors.magenta("\n📂 Directory Configuration:")
    show(" ├─ Working Directory: ", args.working_dir)
    show(" └─ Input Directory: ", args.input_dir)

    # LLM Configuration
    ASCIIColors.magenta("\n🤖 LLM Configuration:")
    show(" ├─ Binding: ", args.llm_binding)
    show(" ├─ Host: ", args.llm_binding_host)
    show(" ├─ Model: ", args.llm_model)
    show(" ├─ Temperature: ", args.temperature)
    show(" ├─ Max Async for LLM: ", args.max_async)
    show(" ├─ Max Tokens: ", args.max_tokens)
    show(" ├─ Timeout: ", args.timeout if args.timeout else "None (infinite)")
    show(" ├─ LLM Cache Enabled: ", args.enable_llm_cache)
    show(" └─ LLM Cache for Extraction Enabled: ", args.enable_llm_cache_for_extract)

    # Embedding Configuration
    ASCIIColors.magenta("\n📊 Embedding Configuration:")
    show(" ├─ Binding: ", args.embedding_binding)
    show(" ├─ Host: ", args.embedding_binding_host)
    show(" ├─ Model: ", args.embedding_model)
    show(" └─ Dimensions: ", args.embedding_dim)

    # RAG Configuration
    ASCIIColors.magenta("\n⚙️ RAG Configuration:")
    show(" ├─ Summary Language: ", args.summary_language)
    show(" ├─ Max Parallel Insert: ", args.max_parallel_insert)
    show(" ├─ Max Embed Tokens: ", args.max_embed_tokens)
    show(" ├─ Chunk Size: ", args.chunk_size)
    show(" ├─ Chunk Overlap Size: ", args.chunk_overlap_size)
    show(" ├─ Cosine Threshold: ", args.cosine_threshold)
    show(" ├─ Top-K: ", args.top_k)
    # The former "Max Token Summary" row is gone: its label was still printed
    # without a value, which glued the next row onto the same output line.
    show(
        " └─ Force LLM Summary on Merge: ",
        get_env_value(
            "FORCE_LLM_SUMMARY_ON_MERGE", DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, int
        ),
    )

    # Storage Configuration
    ASCIIColors.magenta("\n💾 Storage Configuration:")
    show(" ├─ KV Storage: ", args.kv_storage)
    show(" ├─ Vector Storage: ", args.vector_storage)
    show(" ├─ Graph Storage: ", args.graph_storage)
    show(" ├─ Document Status Storage: ", args.doc_status_storage)
    show(" └─ Workspace: ", args.workspace if args.workspace else "-")

    # Server Status
    ASCIIColors.green("\n✨ Server starting up...\n")

    # Server Access Information
    protocol = "https" if args.ssl else "http"
    if args.host == "0.0.0.0":
        # 0.0.0.0 is a bind address, not something to browse to: show localhost
        # URLs plus a placeholder for remote access.
        ASCIIColors.magenta("\n🌐 Server Access Information:")
        show(" ├─ WebUI (local): ", f"{protocol}://localhost:{args.port}")
        show(" ├─ Remote Access: ", f"{protocol}://<your-ip-address>:{args.port}")
        show(" ├─ API Documentation (local): ", f"{protocol}://localhost:{args.port}/docs")
        show(
            " └─ Alternative Documentation (local): ",
            f"{protocol}://localhost:{args.port}/redoc",
        )

        ASCIIColors.magenta("\n📝 Note:")
        ASCIIColors.cyan(""" Since the server is running on 0.0.0.0:
- Use 'localhost' or '127.0.0.1' for local access
- Use your machine's IP address for remote access
- To find your IP address:
• Windows: Run 'ipconfig' in terminal
• Linux/Mac: Run 'ifconfig' or 'ip addr' in terminal
""")
    else:
        base_url = f"{protocol}://{args.host}:{args.port}"
        ASCIIColors.magenta("\n🌐 Server Access Information:")
        show(" ├─ WebUI (local): ", base_url)
        show(" ├─ API Documentation: ", f"{base_url}/docs")
        show(" └─ Alternative Documentation: ", f"{base_url}/redoc")

    # Security Notice
    if args.key:
        ASCIIColors.yellow("\n⚠️ Security Notice:")
        ASCIIColors.white(""" API Key authentication is enabled.
Make sure to include the X-API-Key header in all your requests.
""")
    if args.auth_accounts:
        ASCIIColors.yellow("\n⚠️ Security Notice:")
        ASCIIColors.white(""" JWT authentication is enabled.
Make sure to login before making the request, and include the 'Authorization' in the header.
""")

    # Ensure splash output flush to system log
    sys.stdout.flush()
|