| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  | from fastapi import FastAPI, HTTPException, File, UploadFile, Form | 
					
						
							|  |  |  | from pydantic import BaseModel | 
					
						
							|  |  |  | import logging | 
					
						
							|  |  |  | import argparse | 
					
						
							|  |  |  | from lightrag import LightRAG, QueryParam | 
					
						
							| 
									
										
										
										
											2025-01-15 10:44:12 +08:00
										 |  |  | # from lightrag.llm import lollms_model_complete, lollms_embed | 
					
						
							|  |  |  | # from lightrag.llm import ollama_model_complete, ollama_embed, openai_embedding | 
					
						
							|  |  |  | from lightrag.llm import openai_complete_if_cache, ollama_embedding | 
					
						
							|  |  |  | # from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding | 
					
						
							| 
									
										
										
										
											2025-01-10 20:30:58 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  | from lightrag.utils import EmbeddingFunc | 
					
						
							|  |  |  | from typing import Optional, List | 
					
						
							|  |  |  | from enum import Enum | 
					
						
							|  |  |  | from pathlib import Path | 
					
						
							|  |  |  | import shutil | 
					
						
							|  |  |  | import aiofiles | 
					
						
							|  |  |  | from ascii_colors import trace_exception | 
					
						
							| 
									
										
										
										
											2025-01-04 02:21:37 +01:00
										 |  |  | import os | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  | from fastapi import Depends, Security | 
					
						
							| 
									
										
										
										
											2025-01-04 02:21:37 +01:00
										 |  |  | from fastapi.security import APIKeyHeader | 
					
						
							|  |  |  | from fastapi.middleware.cors import CORSMiddleware | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from starlette.status import HTTP_403_FORBIDDEN | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-15 10:44:12 +08:00
										 |  |  | from dotenv import load_dotenv | 
					
						
							|  |  |  | load_dotenv() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | async def llm_model_func( | 
					
						
							|  |  |  |     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs | 
					
						
							|  |  |  | ) -> str: | 
					
						
							|  |  |  |     return await openai_complete_if_cache( | 
					
						
							|  |  |  |         "deepseek-chat", | 
					
						
							|  |  |  |         prompt, | 
					
						
							|  |  |  |         system_prompt=system_prompt, | 
					
						
							|  |  |  |         history_messages=history_messages, | 
					
						
							|  |  |  |         api_key=os.getenv("DEEPSEEK_API_KEY"), | 
					
						
							|  |  |  |         base_url=os.getenv("DEEPSEEK_ENDPOINT"), | 
					
						
							|  |  |  |         **kwargs, | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2025-01-11 01:37:07 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-10 20:30:58 +01:00
										 |  |  | def get_default_host(binding_type: str) -> str: | 
					
						
							|  |  |  |     default_hosts = { | 
					
						
							| 
									
										
										
										
											2025-01-15 10:44:12 +08:00
										 |  |  |         "ollama": "http://m4.lan.znipower.com:11434", | 
					
						
							| 
									
										
										
										
											2025-01-10 20:30:58 +01:00
										 |  |  |         "lollms": "http://localhost:9600", | 
					
						
							|  |  |  |         "azure_openai": "https://api.openai.com/v1", | 
					
						
							| 
									
										
										
										
											2025-01-15 10:44:12 +08:00
										 |  |  |         "openai": os.getenv("DEEPSEEK_ENDPOINT"), | 
					
						
							| 
									
										
										
										
											2025-01-10 20:30:58 +01:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2025-01-11 01:37:07 +01:00
										 |  |  |     return default_hosts.get( | 
					
						
							|  |  |  |         binding_type, "http://localhost:11434" | 
					
						
							|  |  |  |     )  # fallback to ollama if unknown | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | def parse_args(): | 
					
						
							|  |  |  |     parser = argparse.ArgumentParser( | 
					
						
							|  |  |  |         description="LightRAG FastAPI Server with separate working and input directories" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-11 01:37:07 +01:00
										 |  |  |     # Start by the bindings | 
					
						
							| 
									
										
										
										
											2025-01-10 20:30:58 +01:00
										 |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--llm-binding", | 
					
						
							|  |  |  |         default="ollama", | 
					
						
							|  |  |  |         help="LLM binding to be used. Supported: lollms, ollama, openai (default: ollama)", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--embedding-binding", | 
					
						
							|  |  |  |         default="ollama", | 
					
						
							|  |  |  |         help="Embedding binding to be used. Supported: lollms, ollama, openai (default: ollama)", | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2025-01-11 01:37:07 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-10 20:30:58 +01:00
										 |  |  |     # Parse just these arguments first | 
					
						
							|  |  |  |     temp_args, _ = parser.parse_known_args() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Add remaining arguments with dynamic defaults for hosts | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |     # Server configuration | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--port", type=int, default=9621, help="Server port (default: 9621)" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Directory configuration | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--working-dir", | 
					
						
							|  |  |  |         default="./rag_storage", | 
					
						
							|  |  |  |         help="Working directory for RAG storage (default: ./rag_storage)", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--input-dir", | 
					
						
							|  |  |  |         default="./inputs", | 
					
						
							|  |  |  |         help="Directory containing input documents (default: ./inputs)", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-10 20:30:58 +01:00
										 |  |  |     # LLM Model configuration | 
					
						
							|  |  |  |     default_llm_host = get_default_host(temp_args.llm_binding) | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |     parser.add_argument( | 
					
						
							| 
									
										
										
										
											2025-01-10 20:30:58 +01:00
										 |  |  |         "--llm-binding-host", | 
					
						
							|  |  |  |         default=default_llm_host, | 
					
						
							|  |  |  |         help=f"llm server host URL (default: {default_llm_host})", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--llm-model", | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |         default="mistral-nemo:latest", | 
					
						
							|  |  |  |         help="LLM model name (default: mistral-nemo:latest)", | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2025-01-10 20:30:58 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Embedding model configuration | 
					
						
							|  |  |  |     default_embedding_host = get_default_host(temp_args.embedding_binding) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--embedding-binding-host", | 
					
						
							|  |  |  |         default=default_embedding_host, | 
					
						
							|  |  |  |         help=f"embedding server host URL (default: {default_embedding_host})", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--embedding-model", | 
					
						
							|  |  |  |         default="bge-m3:latest", | 
					
						
							|  |  |  |         help="Embedding model name (default: bge-m3:latest)", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-10 22:17:13 +01:00
										 |  |  |     def timeout_type(value): | 
					
						
							|  |  |  |         if value is None or value == "None": | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  |         return int(value) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-10 21:39:25 +01:00
										 |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--timeout", | 
					
						
							| 
									
										
										
										
											2025-01-10 22:17:13 +01:00
										 |  |  |         default=None, | 
					
						
							|  |  |  |         type=timeout_type, | 
					
						
							|  |  |  |         help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout", | 
					
						
							| 
									
										
										
										
											2025-01-10 21:39:25 +01:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |     # RAG configuration | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--max-async", type=int, default=4, help="Maximum async operations (default: 4)" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--max-tokens", | 
					
						
							|  |  |  |         type=int, | 
					
						
							|  |  |  |         default=32768, | 
					
						
							|  |  |  |         help="Maximum token size (default: 32768)", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--embedding-dim", | 
					
						
							|  |  |  |         type=int, | 
					
						
							|  |  |  |         default=1024, | 
					
						
							|  |  |  |         help="Embedding dimensions (default: 1024)", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--max-embed-tokens", | 
					
						
							|  |  |  |         type=int, | 
					
						
							|  |  |  |         default=8192, | 
					
						
							|  |  |  |         help="Maximum embedding token size (default: 8192)", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Logging configuration | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--log-level", | 
					
						
							|  |  |  |         default="INFO", | 
					
						
							|  |  |  |         choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], | 
					
						
							|  |  |  |         help="Logging level (default: INFO)", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--key", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         help="API key for authentication. This protects lightrag server against unauthorized access", | 
					
						
							|  |  |  |         default=None, | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2025-01-04 02:21:37 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-10 21:39:25 +01:00
										 |  |  |     # Optional https parameters | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							| 
									
										
										
										
											2025-01-11 01:37:07 +01:00
										 |  |  |         "--ssl", action="store_true", help="Enable HTTPS (default: False)" | 
					
						
							| 
									
										
										
										
											2025-01-10 21:39:25 +01:00
										 |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--ssl-certfile", | 
					
						
							|  |  |  |         default=None, | 
					
						
							| 
									
										
										
										
											2025-01-11 01:37:07 +01:00
										 |  |  |         help="Path to SSL certificate file (required if --ssl is enabled)", | 
					
						
							| 
									
										
										
										
											2025-01-10 21:39:25 +01:00
										 |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--ssl-keyfile", | 
					
						
							| 
									
										
										
										
											2025-01-11 01:37:07 +01:00
										 |  |  |         default=None, | 
					
						
							|  |  |  |         help="Path to SSL private key file (required if --ssl is enabled)", | 
					
						
							| 
									
										
										
										
											2025-01-10 21:39:25 +01:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |     return parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class DocumentManager: | 
					
						
							|  |  |  |     """Handles document operations and tracking""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")): | 
					
						
							|  |  |  |         self.input_dir = Path(input_dir) | 
					
						
							|  |  |  |         self.supported_extensions = supported_extensions | 
					
						
							|  |  |  |         self.indexed_files = set() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Create input directory if it doesn't exist | 
					
						
							|  |  |  |         self.input_dir.mkdir(parents=True, exist_ok=True) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def scan_directory(self) -> List[Path]: | 
					
						
							|  |  |  |         """Scan input directory for new files""" | 
					
						
							|  |  |  |         new_files = [] | 
					
						
							|  |  |  |         for ext in self.supported_extensions: | 
					
						
							|  |  |  |             for file_path in self.input_dir.rglob(f"*{ext}"): | 
					
						
							|  |  |  |                 if file_path not in self.indexed_files: | 
					
						
							|  |  |  |                     new_files.append(file_path) | 
					
						
							|  |  |  |         return new_files | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def mark_as_indexed(self, file_path: Path): | 
					
						
							|  |  |  |         """Mark a file as indexed""" | 
					
						
							|  |  |  |         self.indexed_files.add(file_path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def is_supported_file(self, filename: str) -> bool: | 
					
						
							|  |  |  |         """Check if file type is supported""" | 
					
						
							|  |  |  |         return any(filename.lower().endswith(ext) for ext in self.supported_extensions) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # Pydantic models | 
					
						
class SearchMode(str, Enum):
    # Allowed values for QueryRequest.mode. Because the enum also subclasses
    # str, each member compares equal to its string value.
    naive = "naive"
    local = "local"
    # Named with a trailing underscore because ``global`` is a Python keyword;
    # the member's value is still the string "global".
    global_ = "global"
    hybrid = "hybrid"
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class QueryRequest(BaseModel):
    # Request body model for a query.
    # The query text.
    query: str
    # Search mode; defaults to hybrid when omitted.
    mode: SearchMode = SearchMode.hybrid
    # NOTE(review): presumably asks the handler to stream the answer; the
    # handler is not visible here — confirm.
    stream: bool = False
    # NOTE(review): presumably returns only the retrieved context instead of a
    # generated answer; confirm against the query handler.
    only_need_context: bool = False
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class QueryResponse(BaseModel):
    # Response body model for a query: the answer text.
    response: str
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class InsertTextRequest(BaseModel):
    # Request body model for inserting raw text.
    # The text to insert.
    text: str
    # Optional free-form description. NOTE(review): its use is not visible in
    # this part of the file — confirm whether the handler reads it.
    description: Optional[str] = None
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
class InsertResponse(BaseModel):
    # Response body model for insert operations.
    # Outcome label (e.g. "success"; exact values set by the handlers).
    status: str
    # Human-readable detail about the outcome.
    message: str
    # Number of documents involved. NOTE(review): whether this is the count
    # inserted by this request or a running total is decided by the handler —
    # confirm.
    document_count: int
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:21:37 +01:00
										 |  |  | def get_api_key_dependency(api_key: Optional[str]): | 
					
						
							|  |  |  |     if not api_key: | 
					
						
							|  |  |  |         # If no API key is configured, return a dummy dependency that always succeeds | 
					
						
							|  |  |  |         async def no_auth(): | 
					
						
							|  |  |  |             return None | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:21:37 +01:00
										 |  |  |         return no_auth | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:21:37 +01:00
										 |  |  |     # If API key is configured, use proper authentication | 
					
						
							|  |  |  |     api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:21:37 +01:00
										 |  |  |     async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): | 
					
						
							|  |  |  |         if not api_key_header_value: | 
					
						
							|  |  |  |             raise HTTPException( | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  |                 status_code=HTTP_403_FORBIDDEN, detail="API Key required" | 
					
						
							| 
									
										
										
										
											2025-01-04 02:21:37 +01:00
										 |  |  |             ) | 
					
						
							|  |  |  |         if api_key_header_value != api_key: | 
					
						
							|  |  |  |             raise HTTPException( | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  |                 status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" | 
					
						
							| 
									
										
										
										
											2025-01-04 02:21:37 +01:00
										 |  |  |             ) | 
					
						
							|  |  |  |         return api_key_header_value | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:21:37 +01:00
										 |  |  |     return api_key_auth | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | def create_app(args): | 
					
						
							| 
									
										
										
										
											2025-01-10 20:30:58 +01:00
										 |  |  |     # Verify that bindings arer correctly setup | 
					
						
							|  |  |  |     if args.llm_binding not in ["lollms", "ollama", "openai"]: | 
					
						
							|  |  |  |         raise Exception("llm binding not supported") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if args.embedding_binding not in ["lollms", "ollama", "openai"]: | 
					
						
							|  |  |  |         raise Exception("embedding binding not supported") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-11 01:35:49 +01:00
										 |  |  |     # Add SSL validation | 
					
						
							|  |  |  |     if args.ssl: | 
					
						
							|  |  |  |         if not args.ssl_certfile or not args.ssl_keyfile: | 
					
						
							| 
									
										
										
										
											2025-01-11 01:37:07 +01:00
										 |  |  |             raise Exception( | 
					
						
							|  |  |  |                 "SSL certificate and key files must be provided when SSL is enabled" | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2025-01-11 01:35:49 +01:00
										 |  |  |         if not os.path.exists(args.ssl_certfile): | 
					
						
							|  |  |  |             raise Exception(f"SSL certificate file not found: {args.ssl_certfile}") | 
					
						
							|  |  |  |         if not os.path.exists(args.ssl_keyfile): | 
					
						
							|  |  |  |             raise Exception(f"SSL key file not found: {args.ssl_keyfile}") | 
					
						
							| 
									
										
										
										
											2025-01-11 01:37:07 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |     # Setup logging | 
					
						
							|  |  |  |     logging.basicConfig( | 
					
						
							|  |  |  |         format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:21:37 +01:00
										 |  |  |     # Check if API key is provided either through env var or args | 
					
						
							|  |  |  |     api_key = os.getenv("LIGHTRAG_API_KEY") or args.key | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:21:37 +01:00
										 |  |  |     # Initialize FastAPI | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |     app = FastAPI( | 
					
						
							|  |  |  |         title="LightRAG API", | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  |         description="API for querying text using LightRAG with separate storage and input directories" | 
					
						
							|  |  |  |         + "(With authentication)" | 
					
						
							|  |  |  |         if api_key | 
					
						
							|  |  |  |         else "", | 
					
						
							| 
									
										
										
										
											2025-01-10 20:30:58 +01:00
										 |  |  |         version="1.0.1", | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  |         openapi_tags=[{"name": "api"}], | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:21:37 +01:00
										 |  |  |     # Add CORS middleware | 
					
						
							|  |  |  |     app.add_middleware( | 
					
						
							|  |  |  |         CORSMiddleware, | 
					
						
							|  |  |  |         allow_origins=["*"], | 
					
						
							|  |  |  |         allow_credentials=True, | 
					
						
							|  |  |  |         allow_methods=["*"], | 
					
						
							|  |  |  |         allow_headers=["*"], | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Create the optional API key dependency | 
					
						
							|  |  |  |     optional_api_key = get_api_key_dependency(api_key) | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # Create working directory if it doesn't exist | 
					
						
							|  |  |  |     Path(args.working_dir).mkdir(parents=True, exist_ok=True) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Initialize document manager | 
					
						
							|  |  |  |     doc_manager = DocumentManager(args.input_dir) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Initialize RAG | 
					
						
							|  |  |  |     rag = LightRAG( | 
					
						
							|  |  |  |         working_dir=args.working_dir, | 
					
						
							| 
									
										
										
										
											2025-01-15 10:44:12 +08:00
										 |  |  |         llm_model_func=llm_model_func, | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |         embedding_func=EmbeddingFunc( | 
					
						
							| 
									
										
										
										
											2025-01-15 10:44:12 +08:00
										 |  |  |             embedding_dim=1024, | 
					
						
							|  |  |  |             max_token_size=8192, | 
					
						
							|  |  |  |             func=lambda texts: ollama_embedding( | 
					
						
							|  |  |  |                 texts, embed_model="bge-m3:latest", host="http://m4.lan.znipower.com:11434" | 
					
						
							| 
									
										
										
										
											2025-01-11 01:37:07 +01:00
										 |  |  |             ), | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |         ), | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @app.on_event("startup") | 
					
						
							|  |  |  |     async def startup_event(): | 
					
						
							|  |  |  |         """Index all files in input directory during startup""" | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             new_files = doc_manager.scan_directory() | 
					
						
							|  |  |  |             for file_path in new_files: | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     # Use async file reading | 
					
						
							|  |  |  |                     async with aiofiles.open(file_path, "r", encoding="utf-8") as f: | 
					
						
							|  |  |  |                         content = await f.read() | 
					
						
							|  |  |  |                         # Use the async version of insert directly | 
					
						
							|  |  |  |                         await rag.ainsert(content) | 
					
						
							|  |  |  |                         doc_manager.mark_as_indexed(file_path) | 
					
						
							|  |  |  |                         logging.info(f"Indexed file: {file_path}") | 
					
						
							|  |  |  |                 except Exception as e: | 
					
						
							|  |  |  |                     trace_exception(e) | 
					
						
							|  |  |  |                     logging.error(f"Error indexing file {file_path}: {str(e)}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             logging.error(f"Error during startup indexing: {str(e)}") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:21:37 +01:00
										 |  |  |     @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |     async def scan_for_new_documents(): | 
					
						
							|  |  |  |         """Manually trigger scanning for new documents""" | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             new_files = doc_manager.scan_directory() | 
					
						
							|  |  |  |             indexed_count = 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             for file_path in new_files: | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     with open(file_path, "r", encoding="utf-8") as f: | 
					
						
							|  |  |  |                         content = f.read() | 
					
						
							| 
									
										
										
										
											2024-12-26 22:48:52 +01:00
										 |  |  |                         await rag.ainsert(content) | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |                         doc_manager.mark_as_indexed(file_path) | 
					
						
							|  |  |  |                         indexed_count += 1 | 
					
						
							|  |  |  |                 except Exception as e: | 
					
						
							|  |  |  |                     logging.error(f"Error indexing file {file_path}: {str(e)}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return { | 
					
						
							|  |  |  |                 "status": "success", | 
					
						
							|  |  |  |                 "indexed_count": indexed_count, | 
					
						
							|  |  |  |                 "total_documents": len(doc_manager.indexed_files), | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             raise HTTPException(status_code=500, detail=str(e)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:21:37 +01:00
										 |  |  |     @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |     async def upload_to_input_dir(file: UploadFile = File(...)): | 
					
						
							|  |  |  |         """Upload a file to the input directory""" | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             if not doc_manager.is_supported_file(file.filename): | 
					
						
							|  |  |  |                 raise HTTPException( | 
					
						
							|  |  |  |                     status_code=400, | 
					
						
							|  |  |  |                     detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             file_path = doc_manager.input_dir / file.filename | 
					
						
							|  |  |  |             with open(file_path, "wb") as buffer: | 
					
						
							|  |  |  |                 shutil.copyfileobj(file.file, buffer) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Immediately index the uploaded file | 
					
						
							|  |  |  |             with open(file_path, "r", encoding="utf-8") as f: | 
					
						
							|  |  |  |                 content = f.read() | 
					
						
							| 
									
										
										
										
											2024-12-26 22:48:52 +01:00
										 |  |  |                 await rag.ainsert(content) | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |                 doc_manager.mark_as_indexed(file_path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return { | 
					
						
							|  |  |  |                 "status": "success", | 
					
						
							|  |  |  |                 "message": f"File uploaded and indexed: {file.filename}", | 
					
						
							|  |  |  |                 "total_documents": len(doc_manager.indexed_files), | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             raise HTTPException(status_code=500, detail=str(e)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  |     @app.post( | 
					
						
							|  |  |  |         "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |     async def query_text(request: QueryRequest): | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             response = await rag.aquery( | 
					
						
							|  |  |  |                 request.query, | 
					
						
							| 
									
										
										
										
											2024-12-26 23:39:10 +01:00
										 |  |  |                 param=QueryParam( | 
					
						
							|  |  |  |                     mode=request.mode, | 
					
						
							|  |  |  |                     stream=request.stream, | 
					
						
							|  |  |  |                     only_need_context=request.only_need_context, | 
					
						
							|  |  |  |                 ), | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if request.stream: | 
					
						
							|  |  |  |                 result = "" | 
					
						
							|  |  |  |                 async for chunk in response: | 
					
						
							|  |  |  |                     result += chunk | 
					
						
							|  |  |  |                 return QueryResponse(response=result) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 return QueryResponse(response=response) | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             raise HTTPException(status_code=500, detail=str(e)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:21:37 +01:00
										 |  |  |     @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |     async def query_text_stream(request: QueryRequest): | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             response = rag.query( | 
					
						
							| 
									
										
										
										
											2024-12-26 23:39:10 +01:00
										 |  |  |                 request.query, | 
					
						
							|  |  |  |                 param=QueryParam( | 
					
						
							|  |  |  |                     mode=request.mode, | 
					
						
							|  |  |  |                     stream=True, | 
					
						
							|  |  |  |                     only_need_context=request.only_need_context, | 
					
						
							|  |  |  |                 ), | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             async def stream_generator(): | 
					
						
							|  |  |  |                 async for chunk in response: | 
					
						
							|  |  |  |                     yield chunk | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return stream_generator() | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             raise HTTPException(status_code=500, detail=str(e)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  |     @app.post( | 
					
						
							|  |  |  |         "/documents/text", | 
					
						
							|  |  |  |         response_model=InsertResponse, | 
					
						
							|  |  |  |         dependencies=[Depends(optional_api_key)], | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |     async def insert_text(request: InsertTextRequest): | 
					
						
							|  |  |  |         try: | 
					
						
							| 
									
										
										
										
											2025-01-12 12:56:08 +01:00
										 |  |  |             await rag.ainsert(request.text) | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |             return InsertResponse( | 
					
						
							|  |  |  |                 status="success", | 
					
						
							|  |  |  |                 message="Text successfully inserted", | 
					
						
							| 
									
										
										
										
											2025-01-12 12:46:23 +01:00
										 |  |  |                 document_count=1, | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |             ) | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             raise HTTPException(status_code=500, detail=str(e)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  |     @app.post( | 
					
						
							|  |  |  |         "/documents/file", | 
					
						
							|  |  |  |         response_model=InsertResponse, | 
					
						
							|  |  |  |         dependencies=[Depends(optional_api_key)], | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |     async def insert_file(file: UploadFile = File(...), description: str = Form(None)): | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             content = await file.read() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if file.filename.endswith((".txt", ".md")): | 
					
						
							|  |  |  |                 text = content.decode("utf-8") | 
					
						
							| 
									
										
										
										
											2024-12-26 22:48:52 +01:00
										 |  |  |                 await rag.ainsert(text) | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |             else: | 
					
						
							|  |  |  |                 raise HTTPException( | 
					
						
							|  |  |  |                     status_code=400, | 
					
						
							|  |  |  |                     detail="Unsupported file type. Only .txt and .md files are supported", | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return InsertResponse( | 
					
						
							|  |  |  |                 status="success", | 
					
						
							|  |  |  |                 message=f"File '{file.filename}' successfully inserted", | 
					
						
							| 
									
										
										
										
											2024-12-26 22:48:52 +01:00
										 |  |  |                 document_count=1, | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |             ) | 
					
						
							|  |  |  |         except UnicodeDecodeError: | 
					
						
							|  |  |  |             raise HTTPException(status_code=400, detail="File encoding not supported") | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             raise HTTPException(status_code=500, detail=str(e)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
										 |  |  |     @app.post( | 
					
						
							|  |  |  |         "/documents/batch", | 
					
						
							|  |  |  |         response_model=InsertResponse, | 
					
						
							|  |  |  |         dependencies=[Depends(optional_api_key)], | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |     async def insert_batch(files: List[UploadFile] = File(...)): | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             inserted_count = 0 | 
					
						
							|  |  |  |             failed_files = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             for file in files: | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     content = await file.read() | 
					
						
							|  |  |  |                     if file.filename.endswith((".txt", ".md")): | 
					
						
							|  |  |  |                         text = content.decode("utf-8") | 
					
						
							| 
									
										
										
										
											2024-12-26 22:48:52 +01:00
										 |  |  |                         await rag.ainsert(text) | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |                         inserted_count += 1 | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         failed_files.append(f"{file.filename} (unsupported type)") | 
					
						
							|  |  |  |                 except Exception as e: | 
					
						
							|  |  |  |                     failed_files.append(f"{file.filename} ({str(e)})") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             status_message = f"Successfully inserted {inserted_count} documents" | 
					
						
							|  |  |  |             if failed_files: | 
					
						
							|  |  |  |                 status_message += f". Failed files: {', '.join(failed_files)}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return InsertResponse( | 
					
						
							|  |  |  |                 status="success" if inserted_count > 0 else "partial_success", | 
					
						
							|  |  |  |                 message=status_message, | 
					
						
							| 
									
										
										
										
											2024-12-26 22:48:52 +01:00
										 |  |  |                 document_count=len(files), | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |             ) | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             raise HTTPException(status_code=500, detail=str(e)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:23:39 +01:00
    @app.delete(
        "/documents",
        response_model=InsertResponse,
        dependencies=[Depends(optional_api_key)],
    )
    async def clear_documents():
        """Reset document-related attributes on the shared ``rag`` instance.

        Returns:
            InsertResponse: A success status with ``document_count=0``.

        Raises:
            HTTPException: 500 wrapping any error raised while resetting.
        """
        try:
            # NOTE(review): this only rebinds attributes on the in-memory ``rag``
            # object. Nothing visible here deletes persisted data in the working
            # directory, and ``doc_manager.indexed_files`` is left untouched, so
            # /health still reports the previous count. Setting the vector DBs to
            # None presumably breaks later queries/inserts that use them; verify
            # against LightRAG and replace with a proper storage reset.
            rag.text_chunks = []
            rag.entities_vdb = None
            rag.relationships_vdb = None
            return InsertResponse(
                status="success",
                message="All documents cleared successfully",
                document_count=0,
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-04 02:21:37 +01:00
										 |  |  |     @app.get("/health", dependencies=[Depends(optional_api_key)]) | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |     async def get_status(): | 
					
						
							|  |  |  |         """Get current system status""" | 
					
						
							|  |  |  |         return { | 
					
						
							|  |  |  |             "status": "healthy", | 
					
						
							|  |  |  |             "working_directory": str(args.working_dir), | 
					
						
							|  |  |  |             "input_directory": str(args.input_dir), | 
					
						
							|  |  |  |             "indexed_files": len(doc_manager.indexed_files), | 
					
						
							|  |  |  |             "configuration": { | 
					
						
							| 
									
										
										
										
											2025-01-10 20:30:58 +01:00
										 |  |  |                 # LLM configuration binding/host address (if applicable)/model (if applicable) | 
					
						
							|  |  |  |                 "llm_binding": args.llm_binding, | 
					
						
							|  |  |  |                 "llm_binding_host": args.llm_binding_host, | 
					
						
							|  |  |  |                 "llm_model": args.llm_model, | 
					
						
							|  |  |  |                 # embedding model configuration binding/host address (if applicable)/model (if applicable) | 
					
						
							|  |  |  |                 "embedding_binding": args.embedding_binding, | 
					
						
							|  |  |  |                 "embedding_binding_host": args.embedding_binding_host, | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |                 "embedding_model": args.embedding_model, | 
					
						
							|  |  |  |                 "max_tokens": args.max_tokens, | 
					
						
							|  |  |  |             }, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return app | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-24 10:18:41 +01:00
										 |  |  | def main(): | 
					
						
							| 
									
										
										
										
											2024-12-22 00:38:38 +01:00
										 |  |  |     args = parse_args() | 
					
						
							|  |  |  |     import uvicorn | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     app = create_app(args) | 
					
						
							| 
									
										
										
										
											2025-01-11 01:35:49 +01:00
										 |  |  |     uvicorn_config = { | 
					
						
							|  |  |  |         "app": app, | 
					
						
							|  |  |  |         "host": args.host, | 
					
						
							|  |  |  |         "port": args.port, | 
					
						
							| 
									
										
										
										
											2025-01-11 01:37:07 +01:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2025-01-11 01:35:49 +01:00
										 |  |  |     if args.ssl: | 
					
						
							| 
									
										
										
										
											2025-01-11 01:37:07 +01:00
										 |  |  |         uvicorn_config.update( | 
					
						
							|  |  |  |             { | 
					
						
							|  |  |  |                 "ssl_certfile": args.ssl_certfile, | 
					
						
							|  |  |  |                 "ssl_keyfile": args.ssl_keyfile, | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2025-01-11 01:35:49 +01:00
										 |  |  |     uvicorn.run(**uvicorn_config) | 
					
						
							| 
									
										
										
										
											2024-12-24 10:18:41 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-24 10:35:00 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-24 10:18:41 +01:00
										 |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     main() |