from __future__ import annotations

import asyncio
import csv
import html
import json
import logging
import logging.handlers
import os
import re
import weakref
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Protocol, Callable, TYPE_CHECKING, List

import numpy as np
from dotenv import load_dotenv

from lightrag.prompt import PROMPTS
from lightrag.constants import (
    DEFAULT_LOG_MAX_BYTES,
    DEFAULT_LOG_BACKUP_COUNT,
    DEFAULT_LOG_FILENAME,
)


def get_env_value(
    env_key: str, default: Any, value_type: type = str, special_none: bool = False
) -> Any:
    """
    Get a value from an environment variable, with type conversion.

    Args:
        env_key (str): Environment variable key
        default (Any): Default value if the env variable is not set
        value_type (type): Type to convert the value to
        special_none (bool): If True, return None when the value is the string "None"

    Returns:
        Any: Converted value from the environment, or the default
    """
    value = os.getenv(env_key)
    if value is None:
        return default

    # Handle the special case of a literal "None" string
    if special_none and value == "None":
        return None

    if value_type is bool:
        return value.lower() in ("true", "1", "yes", "t", "on")
    try:
        return value_type(value)
    except (ValueError, TypeError):
        return default


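# Usage sketch (env keys and values here are illustrative only):
#
#   os.environ["LOG_MAX_BYTES"] = "1048576"
#   get_env_value("LOG_MAX_BYTES", DEFAULT_LOG_MAX_BYTES, int)  # -> 1048576
#   os.environ["VERBOSE"] = "yes"
#   get_env_value("VERBOSE", False, bool)  # -> True ("true"/"1"/"yes"/"t"/"on")
#   get_env_value("UNSET_KEY", "fallback")  # -> "fallback"

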
# Use TYPE_CHECKING to avoid circular imports
if TYPE_CHECKING:
    from lightrag.base import BaseKVStorage

# Use the .env file inside the current working directory.
# This allows a different .env file for each LightRAG instance.
# OS environment variables take precedence over the .env file.
load_dotenv(dotenv_path=".env", override=False)

VERBOSE_DEBUG = os.getenv("VERBOSE", "false").lower() == "true"


def verbose_debug(msg: str, *args, **kwargs):
    """Output detailed debug information.

    When VERBOSE_DEBUG is True, outputs the complete message.
    When VERBOSE_DEBUG is False, outputs only the first 100 characters.

    Args:
        msg: The message format string
        *args: Arguments to be formatted into the message
        **kwargs: Keyword arguments passed to logger.debug()
    """
    if VERBOSE_DEBUG:
        logger.debug(msg, *args, **kwargs)
    else:
        # Format the message with args first
        formatted_msg = msg % args if args else msg
        # Then truncate the formatted message
        truncated_msg = (
            formatted_msg[:100] + "..." if len(formatted_msg) > 100 else formatted_msg
        )
        logger.debug(truncated_msg, **kwargs)


def set_verbose_debug(enabled: bool):
    """Enable or disable verbose debug output"""
    global VERBOSE_DEBUG
    VERBOSE_DEBUG = enabled


statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}

# Initialize logger
logger = logging.getLogger("lightrag")
logger.propagate = False  # prevent log messages from being sent to the root logger
# Let the main application configure the handlers
logger.setLevel(logging.INFO)

# Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING)


class LightragPathFilter(logging.Filter):
    """Filter for the lightrag logger that drops noisy, frequent path access logs"""

    def __init__(self):
        super().__init__()
        # Define paths to be filtered
        self.filtered_paths = [
            "/documents",
            "/health",
            "/webui/",
            "/documents/pipeline_status",
        ]

    def filter(self, record):
        try:
            # Check if the record has the required attributes for an access log
            if not hasattr(record, "args") or not isinstance(record.args, tuple):
                return True
            if len(record.args) < 5:
                return True

            # Extract method, path, and status from the record args
            method = record.args[1]
            path = record.args[2]
            status = record.args[4]

            # Filter out successful GET requests to filtered paths
            if (
                method == "GET"
                and status in (200, 304)
                and path in self.filtered_paths
            ):
                return False

            return True
        except Exception:
            # In case of any error, let the message through
            return True


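# Usage sketch: attach the filter to a web server's access logger.  The index
# layout (args[1]=method, args[2]=path, args[4]=status) matches uvicorn-style
# access-log records; other servers may differ (an assumption, not something
# this module enforces):
#
#   logging.getLogger("uvicorn.access").addFilter(LightragPathFilter())

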
def setup_logger(
    logger_name: str,
    level: str = "INFO",
    add_filter: bool = False,
    log_file_path: str | None = None,
    enable_file_logging: bool = True,
):
    """Set up a logger with a console handler and, optionally, a rotating file handler.

    Args:
        logger_name: Name of the logger to set up
        level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
        add_filter: Whether to add LightragPathFilter to the logger
        log_file_path: Path to the log file. If None and file logging is enabled,
            defaults to DEFAULT_LOG_FILENAME in LOG_DIR or the current working directory
        enable_file_logging: Whether to enable logging to a file (defaults to True)
    """
    # Configure formatters
    detailed_formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    simple_formatter = logging.Formatter("%(levelname)s: %(message)s")

    logger_instance = logging.getLogger(logger_name)
    logger_instance.setLevel(level)
    logger_instance.handlers = []  # Clear existing handlers
    logger_instance.propagate = False

    # Add console handler
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(simple_formatter)
    console_handler.setLevel(level)
    logger_instance.addHandler(console_handler)

    # Add a file handler by default unless explicitly disabled
    if enable_file_logging:
        # Get the log file path
        if log_file_path is None:
            log_dir = os.getenv("LOG_DIR", os.getcwd())
            log_file_path = os.path.abspath(os.path.join(log_dir, DEFAULT_LOG_FILENAME))

        # Ensure the log directory exists
        os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

        # Get log file max size and backup count from environment variables
        log_max_bytes = get_env_value("LOG_MAX_BYTES", DEFAULT_LOG_MAX_BYTES, int)
        log_backup_count = get_env_value(
            "LOG_BACKUP_COUNT", DEFAULT_LOG_BACKUP_COUNT, int
        )

        try:
            # Add a rotating file handler
            file_handler = logging.handlers.RotatingFileHandler(
                filename=log_file_path,
                maxBytes=log_max_bytes,
                backupCount=log_backup_count,
                encoding="utf-8",
            )
            file_handler.setFormatter(detailed_formatter)
            file_handler.setLevel(level)
            logger_instance.addHandler(file_handler)
        except PermissionError as e:
            logger.warning(f"Could not create log file at {log_file_path}: {str(e)}")
            logger.warning("Continuing with console logging only")

    # Add the path filter if requested
    if add_filter:
        path_filter = LightragPathFilter()
        logger_instance.addFilter(path_filter)


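# Usage sketch (paths and levels are illustrative):
#
#   setup_logger("lightrag", level="DEBUG", add_filter=True,
#                log_file_path="/tmp/lightrag.log")
#   logging.getLogger("lightrag").info("logger ready")

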
class UnlimitedSemaphore:
    """A context manager that allows unlimited access."""

    async def __aenter__(self):
        pass

    async def __aexit__(self, exc_type, exc, tb):
        pass


@dataclass
class EmbeddingFunc:
    embedding_dim: int
    max_token_size: int
    func: Callable

    async def __call__(self, *args, **kwargs) -> np.ndarray:
        return await self.func(*args, **kwargs)


def locate_json_string_body_from_string(content: str) -> str | None:
    """Locate the JSON string body within a string"""
    try:
        maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
        if maybe_json_str is not None:
            maybe_json_str = maybe_json_str.group(0)
            maybe_json_str = maybe_json_str.replace("\\n", "")
            maybe_json_str = maybe_json_str.replace("\n", "")
            maybe_json_str = maybe_json_str.replace("'", '"')
            # Don't json.loads() here; the schema cannot be validated at this point
            return maybe_json_str
    except Exception:
        pass

    return None


def convert_response_to_json(response: str) -> dict[str, Any]:
    json_str = locate_json_string_body_from_string(response)
    assert json_str is not None, f"Unable to parse JSON from response: {response}"
    try:
        data = json.loads(json_str)
        return data
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse JSON: {json_str}")
        raise e from None


def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
    """Compute a hash for the given arguments.

    Args:
        *args: Arguments to hash
        cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')

    Returns:
        str: Hash string
    """
    # Convert all arguments to strings and join them
    args_str = "".join([str(arg) for arg in args])
    if cache_type:
        args_str = f"{cache_type}:{args_str}"

    # Compute the MD5 hash
    return md5(args_str.encode()).hexdigest()


def compute_mdhash_id(content: str, prefix: str = "") -> str:
|
|
|
|
|
"""
|
|
|
|
|
Compute a unique ID for a given content string.
|
|
|
|
|
|
|
|
|
|
The ID is a combination of the given prefix and the MD5 hash of the content string.
|
|
|
|
|
"""
|
2024-10-10 15:02:30 +08:00
|
|
|
|
return prefix + md5(content.encode()).hexdigest()
|
|
|
|
|
|
2024-10-19 09:43:17 +05:30
|
|
|
|
|
2025-04-28 22:52:31 +08:00
|
|
|
|
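# Usage sketch: both helpers are deterministic, so equal inputs always map to
# the same key (prefixes distinguish namespaces such as entities vs relations):
#
#   compute_mdhash_id("Alan Turing", prefix="ent-")  # stable "ent-<md5 hex>" ID
#   compute_args_hash("what is RAG?", "local", cache_type="query")  # cache key

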
# Custom exception class
class QueueFullError(Exception):
    """Raised when the queue is full and the wait times out"""

    pass


def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
    """
    Enhanced priority-limited asynchronous function call decorator

    Args:
        max_size: Maximum number of concurrent calls
        max_queue_size: Maximum queue capacity, to prevent memory overflow

    Returns:
        Decorator function
    """

    def final_decro(func):
        # Ensure func is callable
        if not callable(func):
            raise TypeError(f"Expected a callable object, got {type(func)}")

        queue = asyncio.PriorityQueue(maxsize=max_queue_size)
        tasks = set()
        initialization_lock = asyncio.Lock()
        counter = 0
        shutdown_event = asyncio.Event()
        initialized = False  # Global initialization flag
        worker_health_check_task = None

        # Track active future objects for cleanup
        active_futures = weakref.WeakSet()
        reinit_count = 0  # Reinitialization counter to track system health

        # Worker function to process tasks in the queue
        async def worker():
            """Worker that processes tasks in the priority queue"""
            try:
                while not shutdown_event.is_set():
                    try:
                        # Use a timeout to get tasks, allowing periodic checks of the shutdown signal
                        try:
                            (
                                priority,
                                count,
                                future,
                                args,
                                kwargs,
                            ) = await asyncio.wait_for(queue.get(), timeout=1.0)
                        except asyncio.TimeoutError:
                            # The timeout only exists to check the shutdown signal; continue
                            continue

                        # If the future was cancelled, skip execution
                        if future.cancelled():
                            queue.task_done()
                            continue

                        try:
                            # Execute the function
                            result = await func(*args, **kwargs)
                            # If the future is not done, set the result
                            if not future.done():
                                future.set_result(result)
                        except asyncio.CancelledError:
                            if not future.done():
                                future.cancel()
                            logger.debug("limit_async: Task cancelled during execution")
                        except Exception as e:
                            logger.error(
                                f"limit_async: Error in decorated function: {str(e)}"
                            )
                            if not future.done():
                                future.set_exception(e)
                        finally:
                            queue.task_done()
                    except Exception as e:
                        # Catch all exceptions in the worker loop to prevent worker termination
                        logger.error(f"limit_async: Critical error in worker: {str(e)}")
                        await asyncio.sleep(0.1)  # Prevent high CPU usage
            finally:
                logger.debug("limit_async: Worker exiting")

        async def health_check():
            """Periodically check worker health status and recover dead workers"""
            nonlocal initialized
            try:
                while not shutdown_event.is_set():
                    await asyncio.sleep(5)  # Check every 5 seconds

                    # No lock is acquired here; operate on a copy of the task
                    # set to avoid concurrent modification
                    current_tasks = set(tasks)
                    done_tasks = {t for t in current_tasks if t.done()}
                    tasks.difference_update(done_tasks)

                    # Calculate the active task count
                    active_tasks_count = len(tasks)
                    workers_needed = max_size - active_tasks_count

                    if workers_needed > 0:
                        logger.info(
                            f"limit_async: Creating {workers_needed} new workers"
                        )
                        new_tasks = set()
                        for _ in range(workers_needed):
                            task = asyncio.create_task(worker())
                            new_tasks.add(task)
                            task.add_done_callback(tasks.discard)
                        # Update the task set in one operation
                        tasks.update(new_tasks)
            except Exception as e:
                logger.error(f"limit_async: Error in health check: {str(e)}")
            finally:
                logger.debug("limit_async: Health check task exiting")
                initialized = False

        async def ensure_workers():
            """Ensure worker tasks and the health check system are available

            This function checks whether the worker system is already initialized.
            If not, it performs a one-time initialization of all worker tasks
            and starts the health check system.
            """
            nonlocal initialized, worker_health_check_task, tasks, reinit_count

            if initialized:
                return

            async with initialization_lock:
                if initialized:
                    return

                # Increment the reinitialization counter if this is not the first initialization
                if reinit_count > 0:
                    reinit_count += 1
                    logger.warning(
                        f"limit_async: Reinitialization needed (count: {reinit_count})"
                    )
                else:
                    reinit_count = 1  # First initialization

                # Check for completed tasks and remove them from the task set
                current_tasks = set(tasks)
                done_tasks = {t for t in current_tasks if t.done()}
                tasks.difference_update(done_tasks)

                # Log the active task count during reinitialization
                active_tasks_count = len(tasks)
                if active_tasks_count > 0 and reinit_count > 1:
                    logger.warning(
                        f"limit_async: {active_tasks_count} tasks still running during reinitialization"
                    )

                # Create initial worker tasks, only adding the number needed
                workers_needed = max_size - active_tasks_count
                for _ in range(workers_needed):
                    task = asyncio.create_task(worker())
                    tasks.add(task)
                    task.add_done_callback(tasks.discard)

                # Start the health check
                worker_health_check_task = asyncio.create_task(health_check())

                initialized = True
                logger.info(f"limit_async: {workers_needed} new workers initialized")

        async def shutdown():
            """Gracefully shut down all workers and the queue"""
            logger.info("limit_async: Shutting down priority queue workers")

            # Set the shutdown event
            shutdown_event.set()

            # Cancel all active futures
            for future in list(active_futures):
                if not future.done():
                    future.cancel()

            # Wait for the queue to empty
            try:
                await asyncio.wait_for(queue.join(), timeout=5.0)
            except asyncio.TimeoutError:
                logger.warning(
                    "limit_async: Timeout waiting for queue to empty during shutdown"
                )

            # Cancel all worker tasks
            for task in list(tasks):
                if not task.done():
                    task.cancel()

            # Wait for all tasks to complete
            if tasks:
                await asyncio.gather(*tasks, return_exceptions=True)

            # Cancel the health check task
            if worker_health_check_task and not worker_health_check_task.done():
                worker_health_check_task.cancel()
                try:
                    await worker_health_check_task
                except asyncio.CancelledError:
                    pass

            logger.info("limit_async: Priority queue workers shutdown complete")

        @wraps(func)
        async def wait_func(
            *args, _priority=10, _timeout=None, _queue_timeout=None, **kwargs
        ):
            """
            Execute the function with priority-based concurrency control

            Args:
                *args: Positional arguments passed to the function
                _priority: Call priority (lower values have higher priority)
                _timeout: Maximum time to wait for function completion (in seconds)
                _queue_timeout: Maximum time to wait for entering the queue (in seconds)
                **kwargs: Keyword arguments passed to the function

            Returns:
                The result of the function call

            Raises:
                TimeoutError: If the function call times out
                QueueFullError: If the queue is full and waiting times out
                Any exception raised by the decorated function
            """
            # Ensure the worker system is initialized
            await ensure_workers()

            # Create a future for the result
            future = asyncio.Future()
            active_futures.add(future)

            nonlocal counter
            async with initialization_lock:
                current_count = counter  # Use a local variable to avoid race conditions
                counter += 1

            # Try to put the task into the queue, optionally with a timeout
            try:
                if _queue_timeout is not None:
                    # Use a timeout to wait for queue space
                    try:
                        await asyncio.wait_for(
                            # current_count ensures FIFO order among equal priorities
                            queue.put((_priority, current_count, future, args, kwargs)),
                            timeout=_queue_timeout,
                        )
                    except asyncio.TimeoutError:
                        raise QueueFullError(
                            f"Queue full, timeout after {_queue_timeout} seconds"
                        )
                else:
                    # No timeout; may wait indefinitely
                    # current_count ensures FIFO order among equal priorities
                    await queue.put((_priority, current_count, future, args, kwargs))
            except Exception as e:
                # Clean up the future
                if not future.done():
                    future.set_exception(e)
                active_futures.discard(future)
                raise

            try:
                # Wait for the result, with an optional timeout
                if _timeout is not None:
                    try:
                        return await asyncio.wait_for(future, _timeout)
                    except asyncio.TimeoutError:
                        # Cancel the future
                        if not future.done():
                            future.cancel()
                        raise TimeoutError(
                            f"limit_async: Task timed out after {_timeout} seconds"
                        )
                else:
                    # Wait for the result without a timeout
                    return await future
            finally:
                # Clean up the future reference
                active_futures.discard(future)

        # Attach the shutdown method to the decorated function
        wait_func.shutdown = shutdown

        return wait_func

    return final_decro


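# Usage sketch ("fetch" is a hypothetical coroutine, shown for illustration):
#
#   @priority_limit_async_func_call(max_size=4)
#   async def fetch(url: str) -> str:
#       ...
#
#   # Lower _priority values run first; _timeout/_queue_timeout are optional.
#   result = await fetch("https://example.com", _priority=1, _timeout=30)
#   await fetch.shutdown()  # graceful cleanup when done

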
def wrap_embedding_func_with_attrs(**kwargs):
    """Wrap a function with attributes"""

    def final_decro(func) -> EmbeddingFunc:
        new_func = EmbeddingFunc(**kwargs, func=func)
        return new_func

    return final_decro


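# Usage sketch (dimension and token values are illustrative):
#
#   @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
#   async def my_embed(texts: list[str]) -> np.ndarray:
#       ...  # call an embedding backend here
#
#   # my_embed is now an EmbeddingFunc carrying embedding_dim/max_token_size
#   # metadata while still being awaitable via EmbeddingFunc.__call__.

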
def load_json(file_name):
    if not os.path.exists(file_name):
        return None
    with open(file_name, encoding="utf-8") as f:
        return json.load(f)


def write_json(json_obj, file_name):
    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(json_obj, f, indent=2, ensure_ascii=False)


class TokenizerInterface(Protocol):
    """
    Defines the interface for a tokenizer, requiring encode and decode methods.
    """

    def encode(self, content: str) -> List[int]:
        """Encodes a string into a list of tokens."""
        ...

    def decode(self, tokens: List[int]) -> str:
        """Decodes a list of tokens into a string."""
        ...


class Tokenizer:
    """
    A wrapper around a tokenizer to provide a consistent interface for encoding and decoding.
    """

    def __init__(self, model_name: str, tokenizer: TokenizerInterface):
        """
        Initializes the Tokenizer with a tokenizer model name and a tokenizer instance.

        Args:
            model_name: The associated model name for the tokenizer.
            tokenizer: An instance of a class implementing the TokenizerInterface.
        """
        self.model_name: str = model_name
        self.tokenizer: TokenizerInterface = tokenizer

    def encode(self, content: str) -> List[int]:
        """
        Encodes a string into a list of tokens using the underlying tokenizer.

        Args:
            content: The string to encode.

        Returns:
            A list of integer tokens.
        """
        return self.tokenizer.encode(content)

    def decode(self, tokens: List[int]) -> str:
        """
        Decodes a list of tokens into a string using the underlying tokenizer.

        Args:
            tokens: A list of integer tokens to decode.

        Returns:
            The decoded string.
        """
        return self.tokenizer.decode(tokens)


class TiktokenTokenizer(Tokenizer):
    """
    A Tokenizer implementation using the tiktoken library.
    """

    def __init__(self, model_name: str = "gpt-4o-mini"):
        """
        Initializes the TiktokenTokenizer with a specified model name.

        Args:
            model_name: The model name for the tiktoken tokenizer to use. Defaults to "gpt-4o-mini".

        Raises:
            ImportError: If tiktoken is not installed.
            ValueError: If the model_name is invalid.
        """
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "tiktoken is not installed. Please install it with `pip install tiktoken` or define a custom `tokenizer_func`."
            )

        try:
            tokenizer = tiktoken.encoding_for_model(model_name)
            super().__init__(model_name=model_name, tokenizer=tokenizer)
        except KeyError:
            raise ValueError(f"Invalid model_name: {model_name}.")


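# Usage sketch (requires tiktoken; the exact token IDs depend on the encoding):
#
#   tokenizer = TiktokenTokenizer("gpt-4o-mini")
#   token_ids = tokenizer.encode("hello world")
#   assert tokenizer.decode(token_ids) == "hello world"

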
def pack_user_ass_to_openai_messages(*args: str):
    roles = ["user", "assistant"]
    return [
        {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
    ]


def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split a string by multiple markers"""
    if not markers:
        return [content]
    content = content if content is not None else ""
    results = re.split("|".join(re.escape(marker) for marker in markers), content)
    return [r.strip() for r in results if r.strip()]


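# Behavior sketch:
#
#   split_string_by_multi_markers("a<SEP>b##c", ["<SEP>", "##"])
#   # -> ["a", "b", "c"]   (empty/whitespace-only segments are dropped)

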
# Adapted from the utility functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
    """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
    # If we get non-string input, just give it back
    if not isinstance(input, str):
        return input

    result = html.unescape(input.strip())
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
    return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)


def is_float_regex(value: str) -> bool:
    return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))


def truncate_list_by_token_size(
    list_data: list[Any],
    key: Callable[[Any], str],
    max_token_size: int,
    tokenizer: Tokenizer,
) -> list[Any]:
    """Truncate a list of data by token size"""
    if max_token_size <= 0:
        return []
    tokens = 0
    for i, data in enumerate(list_data):
        tokens += len(tokenizer.encode(key(data)))
        if tokens > max_token_size:
            return list_data[:i]
    return list_data


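# Usage sketch (the budget value is illustrative):
#
#   docs = [{"content": "..."}, {"content": "..."}]
#   kept = truncate_list_by_token_size(
#       docs, key=lambda d: d["content"], max_token_size=4000, tokenizer=tokenizer
#   )
#   # kept is the longest prefix of docs whose total token count stays in budget

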
def process_combine_contexts(*context_lists):
    """
    Combine multiple context lists and remove duplicate content

    Args:
        *context_lists: Any number of context lists

    Returns:
        Combined context list with duplicates removed
    """
    seen_content = {}
    combined_data = []

    # Iterate through all input context lists
    for context_list in context_lists:
        if not context_list:  # Skip empty lists
            continue
        for item in context_list:
            content_dict = {k: v for k, v in item.items() if k != "id"}
            content_key = tuple(sorted(content_dict.items()))
            if content_key not in seen_content:
                seen_content[content_key] = item
                combined_data.append(item)

    # Reassign IDs
    for i, item in enumerate(combined_data):
        item["id"] = str(i + 1)

    return combined_data


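# Deduplication sketch: items identical except for "id" collapse to one,
# and ids are renumbered from "1":
#
#   a = [{"id": "1", "text": "x"}]
#   b = [{"id": "7", "text": "x"}, {"id": "2", "text": "y"}]
#   process_combine_contexts(a, b)
#   # -> [{"id": "1", "text": "x"}, {"id": "2", "text": "y"}]

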
async def get_best_cached_response(
    hashing_kv,
    current_embedding,
    similarity_threshold=0.95,
    mode="default",
    use_llm_check=False,
    llm_func=None,
    original_prompt=None,
    cache_type=None,
) -> str | None:
    logger.debug(
        f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
    )
    mode_cache = await hashing_kv.get_by_id(mode)
    if not mode_cache:
        return None

    best_similarity = -1
    best_response = None
    best_prompt = None
    best_cache_id = None

    # Only iterate through cache entries for this mode
    for cache_id, cache_data in mode_cache.items():
        # Skip if cache_type doesn't match
        if cache_type and cache_data.get("cache_type") != cache_type:
            continue

        # Check if the cache data is valid
        if cache_data["embedding"] is None:
            continue

        try:
            # Safely convert the cached embedding
            cached_quantized = np.frombuffer(
                bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
            ).reshape(cache_data["embedding_shape"])

            # Ensure min_val and max_val are valid float values
            embedding_min = cache_data.get("embedding_min")
            embedding_max = cache_data.get("embedding_max")

            if (
                embedding_min is None
                or embedding_max is None
                or embedding_min >= embedding_max
            ):
                logger.warning(
                    f"Invalid embedding min/max values: min={embedding_min}, max={embedding_max}"
                )
                continue

            cached_embedding = dequantize_embedding(
                cached_quantized,
                embedding_min,
                embedding_max,
            )
        except Exception as e:
            logger.warning(f"Error processing cached embedding: {str(e)}")
            continue

        similarity = cosine_similarity(current_embedding, cached_embedding)
        if similarity > best_similarity:
            best_similarity = similarity
            best_response = cache_data["return"]
            best_prompt = cache_data["original_prompt"]
            best_cache_id = cache_id

    if best_similarity > similarity_threshold:
        # If the LLM check is enabled and all required parameters are provided
        if (
            use_llm_check
            and llm_func
            and original_prompt
            and best_prompt
            and best_response is not None
        ):
            compare_prompt = PROMPTS["similarity_check"].format(
                original_prompt=original_prompt, cached_prompt=best_prompt
            )

            try:
                llm_result = await llm_func(compare_prompt)
                llm_result = llm_result.strip()
                llm_similarity = float(llm_result)

                # Replace the vector similarity with the LLM similarity score
                best_similarity = llm_similarity
                if best_similarity < similarity_threshold:
                    log_data = {
                        "event": "cache_rejected_by_llm",
                        "type": cache_type,
                        "mode": mode,
                        "original_question": original_prompt[:100] + "..."
                        if len(original_prompt) > 100
                        else original_prompt,
                        "cached_question": best_prompt[:100] + "..."
                        if len(best_prompt) > 100
                        else best_prompt,
                        "similarity_score": round(best_similarity, 4),
                        "threshold": similarity_threshold,
                    }
                    logger.debug(json.dumps(log_data, ensure_ascii=False))
                    logger.info(f"Cache rejected by LLM(mode:{mode} type:{cache_type})")
                    return None
            except Exception as e:  # Catch all possible exceptions
                logger.warning(f"LLM similarity check failed: {e}")
                return None  # Return None directly when the LLM check fails

        prompt_display = (
            best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
        )
        log_data = {
            "event": "cache_hit",
            "type": cache_type,
            "mode": mode,
            "similarity": round(best_similarity, 4),
            "cache_id": best_cache_id,
            "original_prompt": prompt_display,
        }
        logger.debug(json.dumps(log_data, ensure_ascii=False))
        return best_response
    return None


def cosine_similarity(v1, v2):
    """Calculate the cosine similarity between two vectors"""
    dot_product = np.dot(v1, v2)
    norm1 = np.linalg.norm(v1)
    norm2 = np.linalg.norm(v2)
    return dot_product / (norm1 * norm2)


def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
    """Quantize an embedding to the specified number of bits"""
    # Convert a list to a numpy array if needed
    if isinstance(embedding, list):
        embedding = np.array(embedding)

    # Calculate min/max values for reconstruction
    min_val = embedding.min()
    max_val = embedding.max()

    if min_val == max_val:
        # Handle a constant vector
        quantized = np.zeros_like(embedding, dtype=np.uint8)
        return quantized, min_val, max_val

    # Quantize to the 0-255 range
    scale = (2**bits - 1) / (max_val - min_val)
    quantized = np.round((embedding - min_val) * scale).astype(np.uint8)

    return quantized, min_val, max_val


def dequantize_embedding(
    quantized: np.ndarray, min_val: float, max_val: float, bits=8
) -> np.ndarray:
    """Restore a quantized embedding"""
    if min_val == max_val:
        # Handle a constant vector
        return np.full_like(quantized, min_val, dtype=np.float32)

    scale = (max_val - min_val) / (2**bits - 1)
    return (quantized * scale + min_val).astype(np.float32)


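# Round-trip sketch: dequantized values match the originals to within one
# 8-bit quantization step, i.e. (max - min) / 255:
#
#   emb = np.random.rand(768).astype(np.float32)
#   q, lo, hi = quantize_embedding(emb)
#   restored = dequantize_embedding(q, lo, hi)
#   assert np.allclose(emb, restored, atol=(hi - lo) / 255)

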
async def handle_cache(
    hashing_kv,
    args_hash,
    prompt,
    mode="default",
    cache_type=None,
):
    """Generic cache handling function"""
    if hashing_kv is None:
        return None, None, None, None

    if mode != "default":  # handle cache for all types of query
        if not hashing_kv.global_config.get("enable_llm_cache"):
            return None, None, None, None
    else:  # handle cache for entity extraction
        if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
            return None, None, None, None

    if exists_func(hashing_kv, "get_by_mode_and_id"):
        mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
    else:
        mode_cache = await hashing_kv.get_by_id(mode) or {}
    if args_hash in mode_cache:
        logger.debug(f"Non-embedding cache hit(mode:{mode} type:{cache_type})")
        return mode_cache[args_hash]["return"], None, None, None

    logger.debug(f"Non-embedding cache miss(mode:{mode} type:{cache_type})")
    return None, None, None, None


@dataclass
class CacheData:
    args_hash: str
    content: str
    prompt: str
    quantized: np.ndarray | None = None
    min_val: float | None = None
    max_val: float | None = None
    mode: str = "default"
    cache_type: str = "query"
    chunk_id: str | None = None


async def save_to_cache(hashing_kv, cache_data: CacheData):
    """Save data to the cache, with improved handling for streaming responses and duplicate content.

    Args:
        hashing_kv: The key-value storage for caching
        cache_data: The cache data to save
    """
    # Skip if the storage is None or there is no content
    if hashing_kv is None or not cache_data.content:
        return

    # If the content is a streaming response, don't cache it
    if hasattr(cache_data.content, "__aiter__"):
        logger.debug("Streaming response detected, skipping cache")
        return

    # Get the existing cache data
    if exists_func(hashing_kv, "get_by_mode_and_id"):
        mode_cache = (
            await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
            or {}
        )
    else:
        mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}

    # Check if we already have identical content cached
    if cache_data.args_hash in mode_cache:
        existing_content = mode_cache[cache_data.args_hash].get("return")
        if existing_content == cache_data.content:
            logger.info(
                f"Cache content unchanged for {cache_data.args_hash}, skipping update"
            )
            return

    # Update the cache with the new content
    mode_cache[cache_data.args_hash] = {
        "return": cache_data.content,
        "cache_type": cache_data.cache_type,
        "chunk_id": cache_data.chunk_id,
        "embedding": cache_data.quantized.tobytes().hex()
        if cache_data.quantized is not None
        else None,
        "embedding_shape": cache_data.quantized.shape
        if cache_data.quantized is not None
        else None,
        "embedding_min": cache_data.min_val,
        "embedding_max": cache_data.max_val,
        "original_prompt": cache_data.prompt,
    }

    logger.info(f" == LLM cache == saving {cache_data.mode}: {cache_data.args_hash}")

    # Only upsert if there's actual new content
    await hashing_kv.upsert({cache_data.mode: mode_cache})


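# Cache flow sketch ("kv" is a hypothetical BaseKVStorage-compatible store
# whose global_config enables enable_llm_cache; "llm" is a hypothetical
# coroutine):
#
#   args_hash = compute_args_hash(prompt, cache_type="query")
#   cached, _, _, _ = await handle_cache(kv, args_hash, prompt,
#                                        mode="local", cache_type="query")
#   if cached is None:
#       answer = await llm(prompt)
#       await save_to_cache(kv, CacheData(args_hash=args_hash, content=answer,
#                                         prompt=prompt, mode="local"))

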
def safe_unicode_decode(content):
    # Regular expression to find all Unicode escape sequences of the form \uXXXX
    unicode_escape_pattern = re.compile(r"\\u([0-9a-fA-F]{4})")

    # Function to replace a Unicode escape with the actual character
    def replace_unicode_escape(match):
        # Convert the matched hexadecimal value into the actual Unicode character
        return chr(int(match.group(1), 16))

    # Perform the substitution
    decoded_content = unicode_escape_pattern.sub(
        replace_unicode_escape, content.decode("utf-8")
    )

    return decoded_content


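# Behavior sketch: the input is bytes, and literal \uXXXX escapes are expanded:
#
#   safe_unicode_decode(b'caf\\u00e9')  # -> "café"

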
def exists_func(obj, func_name: str) -> bool:
    """Check whether a callable attribute exists on an object.

    :param obj: Object to inspect
    :param func_name: Name of the attribute to check
    :return: True / False
    """
    return callable(getattr(obj, func_name, None))


def get_conversation_turns(
    conversation_history: list[dict[str, Any]], num_turns: int
) -> str:
    """
    Process the conversation history to get the specified number of complete turns.

    Args:
        conversation_history: List of conversation messages in chronological order
        num_turns: Number of complete turns to include

    Returns:
        Formatted string of the conversation history
    """
    # Check if num_turns is valid
    if num_turns <= 0:
        return ""

    # Group messages into turns
    turns: list[list[dict[str, Any]]] = []
    messages: list[dict[str, Any]] = []

    # First, filter out keyword extraction messages
    for msg in conversation_history:
        if msg["role"] == "assistant" and (
            msg["content"].startswith('{ "high_level_keywords"')
            or msg["content"].startswith("{'high_level_keywords'")
        ):
            continue
        messages.append(msg)

    # Then process messages in chronological order
    i = 0
    while i < len(messages) - 1:
        msg1 = messages[i]
        msg2 = messages[i + 1]

        # Check if we have a user-assistant or assistant-user pair
        if (msg1["role"] == "user" and msg2["role"] == "assistant") or (
            msg1["role"] == "assistant" and msg2["role"] == "user"
        ):
            # Always put the user message first in the turn
            if msg1["role"] == "assistant":
                turn = [msg2, msg1]  # user, assistant
            else:
                turn = [msg1, msg2]  # user, assistant
            turns.append(turn)
        i += 2

    # Keep only the most recent num_turns
    if len(turns) > num_turns:
        turns = turns[-num_turns:]

    # Format the turns into a string
    formatted_turns: list[str] = []
    for turn in turns:
        formatted_turns.extend(
            [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
        )

    return "\n".join(formatted_turns)


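# Behavior sketch:
#
#   history = [
#       {"role": "user", "content": "hi"},
#       {"role": "assistant", "content": "hello"},
#       {"role": "user", "content": "how are you?"},
#       {"role": "assistant", "content": "fine"},
#   ]
#   get_conversation_turns(history, num_turns=1)
#   # -> "user: how are you?\nassistant: fine"

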
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """
    Ensure that there is always an event loop available.

    This function tries to get the current event loop. If the current event loop is closed or does not exist,
    it creates a new event loop and sets it as the current event loop.

    Returns:
        asyncio.AbstractEventLoop: The current or newly created event loop.
    """
    try:
        # Try to get the current event loop
        current_loop = asyncio.get_event_loop()
        if current_loop.is_closed():
            raise RuntimeError("Event loop is closed.")
        return current_loop

    except RuntimeError:
        # If no event loop exists or it is closed, create a new one
        logger.info("Creating a new event loop in main thread.")
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        return new_loop


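# Usage sketch for synchronous entry points ("some_coroutine" is hypothetical):
#
#   loop = always_get_an_event_loop()
#   result = loop.run_until_complete(some_coroutine())

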
async def aexport_data(
|
|
|
|
|
chunk_entity_relation_graph,
|
|
|
|
|
entities_vdb,
|
|
|
|
|
relationships_vdb,
|
|
|
|
|
output_path: str,
|
|
|
|
|
file_format: str = "csv",
|
|
|
|
|
include_vector_data: bool = False,
|
|
|
|
|
) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Asynchronously exports all entities, relations, and relationships to various formats.
|
2025-04-14 12:08:56 +08:00
|
|
|
|
|
2025-04-14 03:06:23 +08:00
|
|
|
|
Args:
|
|
|
|
|
chunk_entity_relation_graph: Graph storage instance for entities and relations
|
|
|
|
|
entities_vdb: Vector database storage for entities
|
|
|
|
|
relationships_vdb: Vector database storage for relationships
|
|
|
|
|
output_path: The path to the output file (including extension).
|
|
|
|
|
file_format: Output format - "csv", "excel", "md", "txt".
|
|
|
|
|
- csv: Comma-separated values file
|
|
|
|
|
- excel: Microsoft Excel file with multiple sheets
|
|
|
|
|
- md: Markdown tables
|
|
|
|
|
- txt: Plain text formatted output
|
|
|
|
|
include_vector_data: Whether to include data from the vector database.
|
|
|
|
|
"""
|
|
|
|
|
# Collect data
|
|
|
|
|
entities_data = []
|
|
|
|
|
relations_data = []
|
|
|
|
|
relationships_data = []
|
|
|
|
|
|
|
|
|
|
# --- Entities ---
|
|
|
|
|
all_entities = await chunk_entity_relation_graph.get_all_labels()
|
|
|
|
|
for entity_name in all_entities:
|
|
|
|
|
# Get entity information from graph
|
|
|
|
|
node_data = await chunk_entity_relation_graph.get_node(entity_name)
|
|
|
|
|
source_id = node_data.get("source_id") if node_data else None
|
2025-04-14 12:08:56 +08:00
|
|
|
|
|
2025-04-14 03:06:23 +08:00
|
|
|
|
entity_info = {
|
|
|
|
|
"graph_data": node_data,
|
|
|
|
|
"source_id": source_id,
|
|
|
|
|
}
|
2025-04-14 12:08:56 +08:00
|
|
|
|
|
2025-04-14 03:06:23 +08:00
|
|
|
|
# Optional: Get vector database information
|
|
|
|
|
if include_vector_data:
|
|
|
|
|
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
|
|
|
|
vector_data = await entities_vdb.get_by_id(entity_id)
|
|
|
|
|
entity_info["vector_data"] = vector_data
|
2025-04-14 12:08:56 +08:00
|
|
|
|
|
2025-04-14 03:06:23 +08:00
|
|
|
|
entity_row = {
|
|
|
|
|
"entity_name": entity_name,
|
|
|
|
|
"source_id": source_id,
|
2025-04-14 12:08:56 +08:00
|
|
|
|
"graph_data": str(
|
|
|
|
|
entity_info["graph_data"]
|
|
|
|
|
), # Convert to string to ensure compatibility
|
2025-04-14 03:06:23 +08:00
|
|
|
|
}
|
|
|
|
|
if include_vector_data and "vector_data" in entity_info:
|
|
|
|
|
entity_row["vector_data"] = str(entity_info["vector_data"])
|
|
|
|
|
entities_data.append(entity_row)
|
|
|
|
|
|
|
|
|
|
# --- Relations ---
|
|
|
|
|
for src_entity in all_entities:
|
|
|
|
|
for tgt_entity in all_entities:
|
|
|
|
|
if src_entity == tgt_entity:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
edge_exists = await chunk_entity_relation_graph.has_edge(
|
|
|
|
|
src_entity, tgt_entity
|
|
|
|
|
)
|
|
|
|
|
if edge_exists:
|
|
|
|
|
# Get edge information from graph
|
|
|
|
|
edge_data = await chunk_entity_relation_graph.get_edge(
|
|
|
|
|
src_entity, tgt_entity
|
|
|
|
|
)
|
|
|
|
|
source_id = edge_data.get("source_id") if edge_data else None
|
2025-04-14 12:08:56 +08:00
|
|
|
|
|
2025-04-14 03:06:23 +08:00
|
|
|
|
relation_info = {
|
|
|
|
|
"graph_data": edge_data,
|
|
|
|
|
"source_id": source_id,
|
|
|
|
|
}
|
2025-04-14 12:08:56 +08:00
|
|
|
|
|
2025-04-14 03:06:23 +08:00
|
|
|
|
# Optional: Get vector database information
|
|
|
|
|
if include_vector_data:
|
|
|
|
|
rel_id = compute_mdhash_id(src_entity + tgt_entity, prefix="rel-")
|
|
|
|
|
vector_data = await relationships_vdb.get_by_id(rel_id)
|
|
|
|
|
relation_info["vector_data"] = vector_data
|
2025-04-14 12:08:56 +08:00
|
|
|
|
|
2025-04-14 03:06:23 +08:00
|
|
|
|
relation_row = {
|
|
|
|
|
"src_entity": src_entity,
|
|
|
|
|
"tgt_entity": tgt_entity,
|
|
|
|
|
"source_id": relation_info["source_id"],
|
|
|
|
|
"graph_data": str(relation_info["graph_data"]), # Convert to string
|
|
|
|
|
}
|
|
|
|
|
if include_vector_data and "vector_data" in relation_info:
|
|
|
|
|
relation_row["vector_data"] = str(relation_info["vector_data"])
|
|
|
|
|
relations_data.append(relation_row)
|
|
|
|
|
|
|
|
|
|
# --- Relationships (from VectorDB) ---
|
|
|
|
|
all_relationships = await relationships_vdb.client_storage
|
|
|
|
|
for rel in all_relationships["data"]:
|
|
|
|
|
relationships_data.append(
|
|
|
|
|
{
|
|
|
|
|
"relationship_id": rel["__id__"],
|
|
|
|
|
"data": str(rel), # Convert to string for compatibility
|
|
|
|
|
}
|
|
|
|
|
)
    # Export based on format
    if file_format == "csv":
        # CSV export
        with open(output_path, "w", newline="", encoding="utf-8") as csvfile:
            # Entities
            if entities_data:
                csvfile.write("# ENTITIES\n")
                writer = csv.DictWriter(csvfile, fieldnames=entities_data[0].keys())
                writer.writeheader()
                writer.writerows(entities_data)
                csvfile.write("\n\n")

            # Relations
            if relations_data:
                csvfile.write("# RELATIONS\n")
                writer = csv.DictWriter(csvfile, fieldnames=relations_data[0].keys())
                writer.writeheader()
                writer.writerows(relations_data)
                csvfile.write("\n\n")

            # Relationships
            if relationships_data:
                csvfile.write("# RELATIONSHIPS\n")
                writer = csv.DictWriter(
                    csvfile, fieldnames=relationships_data[0].keys()
                )
                writer.writeheader()
                writer.writerows(relationships_data)

    elif file_format == "excel":
        # Excel export
        import pandas as pd

        entities_df = pd.DataFrame(entities_data) if entities_data else pd.DataFrame()
        relations_df = (
            pd.DataFrame(relations_data) if relations_data else pd.DataFrame()
        )
        relationships_df = (
            pd.DataFrame(relationships_data) if relationships_data else pd.DataFrame()
        )

        with pd.ExcelWriter(output_path, engine="xlsxwriter") as writer:
            if not entities_df.empty:
                entities_df.to_excel(writer, sheet_name="Entities", index=False)
            if not relations_df.empty:
                relations_df.to_excel(writer, sheet_name="Relations", index=False)
            if not relationships_df.empty:
                relationships_df.to_excel(
                    writer, sheet_name="Relationships", index=False
                )

    elif file_format == "md":
        # Markdown export
        with open(output_path, "w", encoding="utf-8") as mdfile:
            mdfile.write("# LightRAG Data Export\n\n")

            # Entities
            mdfile.write("## Entities\n\n")
            if entities_data:
                # Write header
                mdfile.write("| " + " | ".join(entities_data[0].keys()) + " |\n")
                mdfile.write(
                    "| " + " | ".join(["---"] * len(entities_data[0].keys())) + " |\n"
                )

                # Write rows
                for entity in entities_data:
                    mdfile.write(
                        "| " + " | ".join(str(v) for v in entity.values()) + " |\n"
                    )
                mdfile.write("\n\n")
            else:
                mdfile.write("*No entity data available*\n\n")

            # Relations
            mdfile.write("## Relations\n\n")
            if relations_data:
                # Write header
                mdfile.write("| " + " | ".join(relations_data[0].keys()) + " |\n")
                mdfile.write(
                    "| " + " | ".join(["---"] * len(relations_data[0].keys())) + " |\n"
                )

                # Write rows
                for relation in relations_data:
                    mdfile.write(
                        "| " + " | ".join(str(v) for v in relation.values()) + " |\n"
                    )
                mdfile.write("\n\n")
            else:
                mdfile.write("*No relation data available*\n\n")

            # Relationships
            mdfile.write("## Relationships\n\n")
            if relationships_data:
                # Write header
                mdfile.write("| " + " | ".join(relationships_data[0].keys()) + " |\n")
                mdfile.write(
                    "| "
                    + " | ".join(["---"] * len(relationships_data[0].keys()))
                    + " |\n"
                )

                # Write rows
                for relationship in relationships_data:
                    mdfile.write(
                        "| "
                        + " | ".join(str(v) for v in relationship.values())
                        + " |\n"
                    )
            else:
                mdfile.write("*No relationship data available*\n\n")

    elif file_format == "txt":
        # Plain text export
        with open(output_path, "w", encoding="utf-8") as txtfile:
            txtfile.write("LIGHTRAG DATA EXPORT\n")
            txtfile.write("=" * 80 + "\n\n")

            # Entities
            txtfile.write("ENTITIES\n")
            txtfile.write("-" * 80 + "\n")
            if entities_data:
                # Create fixed width columns
                col_widths = {
                    k: max(len(k), max(len(str(e[k])) for e in entities_data))
                    for k in entities_data[0]
                }
                header = " ".join(k.ljust(col_widths[k]) for k in entities_data[0])
                txtfile.write(header + "\n")
                txtfile.write("-" * len(header) + "\n")

                # Write rows
                for entity in entities_data:
                    row = " ".join(
                        str(v).ljust(col_widths[k]) for k, v in entity.items()
                    )
                    txtfile.write(row + "\n")
                txtfile.write("\n\n")
            else:
                txtfile.write("No entity data available\n\n")

            # Relations
            txtfile.write("RELATIONS\n")
            txtfile.write("-" * 80 + "\n")
            if relations_data:
                # Create fixed width columns
                col_widths = {
                    k: max(len(k), max(len(str(r[k])) for r in relations_data))
                    for k in relations_data[0]
                }
                header = " ".join(k.ljust(col_widths[k]) for k in relations_data[0])
                txtfile.write(header + "\n")
                txtfile.write("-" * len(header) + "\n")

                # Write rows
                for relation in relations_data:
                    row = " ".join(
                        str(v).ljust(col_widths[k]) for k, v in relation.items()
                    )
                    txtfile.write(row + "\n")
                txtfile.write("\n\n")
            else:
                txtfile.write("No relation data available\n\n")

            # Relationships
            txtfile.write("RELATIONSHIPS\n")
            txtfile.write("-" * 80 + "\n")
            if relationships_data:
                # Create fixed width columns
                col_widths = {
                    k: max(len(k), max(len(str(r[k])) for r in relationships_data))
                    for k in relationships_data[0]
                }
                header = " ".join(
                    k.ljust(col_widths[k]) for k in relationships_data[0]
                )
                txtfile.write(header + "\n")
                txtfile.write("-" * len(header) + "\n")

                # Write rows
                for relationship in relationships_data:
                    row = " ".join(
                        str(v).ljust(col_widths[k]) for k, v in relationship.items()
                    )
                    txtfile.write(row + "\n")
            else:
                txtfile.write("No relationship data available\n\n")

    else:
        raise ValueError(
            f"Unsupported file format: {file_format}. "
            "Choose from: csv, excel, md, txt"
        )
    # Every unsupported format raises above, so reaching this point means the
    # export succeeded
    print(f"Data exported to: {output_path} with format: {file_format}")


def export_data(
    chunk_entity_relation_graph,
    entities_vdb,
    relationships_vdb,
    output_path: str,
    file_format: str = "csv",
    include_vector_data: bool = False,
) -> None:
    """
    Synchronously exports all entities, relations, and relationships to various formats.

    Args:
        chunk_entity_relation_graph: Graph storage instance for entities and relations
        entities_vdb: Vector database storage for entities
        relationships_vdb: Vector database storage for relationships
        output_path: The path to the output file (including extension).
        file_format: Output format - "csv", "excel", "md", "txt".
            - csv: Comma-separated values file
            - excel: Microsoft Excel file with multiple sheets
            - md: Markdown tables
            - txt: Plain text formatted output
        include_vector_data: Whether to include data from the vector database.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    loop.run_until_complete(
        aexport_data(
            chunk_entity_relation_graph,
            entities_vdb,
            relationships_vdb,
            output_path,
            file_format,
            include_vector_data,
        )
    )

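# Usage sketch (hypothetical `rag` LightRAG instance; the storage attribute
# names below are assumptions about the caller's setup):
#
#     export_data(
#         rag.chunk_entity_relation_graph,
#         rag.entities_vdb,
#         rag.relationships_vdb,
#         output_path="graph_export.xlsx",
#         file_format="excel",
#     )
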
def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
    """Lazily import a class from an external module based on the package of the caller."""
    # Get the caller's module and package
    import inspect

    caller_frame = inspect.currentframe().f_back
    module = inspect.getmodule(caller_frame)
    package = module.__package__ if module else None

    def import_class(*args: Any, **kwargs: Any):
        import importlib

        module = importlib.import_module(module_name, package=package)
        cls = getattr(module, class_name)
        return cls(*args, **kwargs)

    return import_class

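# Usage sketch ("Neo4JStorage" and the module path are illustrative, not a
# guaranteed layout of this package):
#
#     Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")
#     storage = Neo4JStorage(...)  # the module is imported only at this call
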
async def use_llm_func_with_cache(
    input_text: str,
    use_llm_func: callable,
    llm_response_cache: "BaseKVStorage | None" = None,
    max_tokens: int | None = None,
    history_messages: list[dict[str, str]] | None = None,
    cache_type: str = "extract",
    chunk_id: str | None = None,
) -> str:
    """Call LLM function with cache support.

    If cache is available and enabled (determined by handle_cache based on mode),
    retrieve the result from cache; otherwise call the LLM function and save the
    result to cache.

    Args:
        input_text: Input text to send to the LLM
        use_llm_func: LLM function to call
        llm_response_cache: Cache storage instance
        max_tokens: Maximum tokens for generation
        history_messages: History messages list
        cache_type: Type of cache entry (e.g. "extract")
        chunk_id: Chunk identifier to store in the cache entry

    Returns:
        LLM response text
    """
    if llm_response_cache:
        if history_messages:
            history = json.dumps(history_messages, ensure_ascii=False)
            _prompt = history + "\n" + input_text
        else:
            _prompt = input_text

        arg_hash = compute_args_hash(_prompt)
        cached_return, _1, _2, _3 = await handle_cache(
            llm_response_cache,
            arg_hash,
            _prompt,
            "default",
            cache_type=cache_type,
        )
        if cached_return:
            logger.debug(f"Found cache for {arg_hash}")
            statistic_data["llm_cache"] += 1
            return cached_return
        statistic_data["llm_call"] += 1

        # Call LLM
        kwargs = {}
        if history_messages:
            kwargs["history_messages"] = history_messages
        if max_tokens is not None:
            kwargs["max_tokens"] = max_tokens

        res: str = await use_llm_func(input_text, **kwargs)

        if llm_response_cache.global_config.get("enable_llm_cache_for_entity_extract"):
            await save_to_cache(
                llm_response_cache,
                CacheData(
                    args_hash=arg_hash,
                    content=res,
                    prompt=_prompt,
                    cache_type=cache_type,
                    chunk_id=chunk_id,
                ),
            )

        return res

    # When cache is disabled, directly call LLM
    kwargs = {}
    if history_messages:
        kwargs["history_messages"] = history_messages
    if max_tokens is not None:
        kwargs["max_tokens"] = max_tokens

    logger.info(f"Call LLM function with query text length: {len(input_text)}")
    return await use_llm_func(input_text, **kwargs)

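# Usage sketch inside an async extraction step (`llm_func` and the cache
# instance are assumptions about the caller's setup):
#
#     result = await use_llm_func_with_cache(
#         prompt_text,
#         llm_func,
#         llm_response_cache=rag.llm_response_cache,
#         cache_type="extract",
#     )
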
def get_content_summary(content: str, max_length: int = 250) -> str:
    """Get summary of document content.

    Args:
        content: Original document content
        max_length: Maximum length of summary

    Returns:
        Truncated content with ellipsis if needed
    """
    content = content.strip()
    if len(content) <= max_length:
        return content
    return content[:max_length] + "..."

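# Example: get_content_summary("a" * 300) returns the first 250 characters
# followed by "...".
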
def normalize_extracted_info(name: str, is_entity=False) -> str:
    """Normalize entity/relation names and descriptions with the following rules:

    1. Remove spaces between Chinese characters
    2. Remove spaces between Chinese characters and English letters/numbers
    3. Preserve spaces within English text and numbers
    4. Replace Chinese parentheses with English parentheses
    5. Replace Chinese dash with English dash
    6. Remove English quotation marks from the beginning and end of the text
    7. Remove English quotation marks in and around Chinese characters
    8. Remove Chinese quotation marks

    Args:
        name: Entity name to normalize
        is_entity: Whether the text is an entity name (enables rules 7 and 8)

    Returns:
        Normalized entity name
    """
    # Replace Chinese parentheses with English parentheses
    name = name.replace("(", "(").replace(")", ")")

    # Replace Chinese dash with English dash
    name = name.replace("—", "-").replace("-", "-")

    # Use regex to remove spaces between Chinese characters
    # Regex explanation:
    # (?<=[\u4e00-\u9fa5]): Positive lookbehind for Chinese character
    # \s+: One or more whitespace characters
    # (?=[\u4e00-\u9fa5]): Positive lookahead for Chinese character
    name = re.sub(r"(?<=[\u4e00-\u9fa5])\s+(?=[\u4e00-\u9fa5])", "", name)

    # Remove spaces between Chinese and English/numbers/symbols
    name = re.sub(
        r"(?<=[\u4e00-\u9fa5])\s+(?=[a-zA-Z0-9\(\)\[\]@#$%!&\*\-=+_])", "", name
    )
    name = re.sub(
        r"(?<=[a-zA-Z0-9\(\)\[\]@#$%!&\*\-=+_])\s+(?=[\u4e00-\u9fa5])", "", name
    )

    # Remove English quotation marks from the beginning and end
    if len(name) >= 2 and name.startswith('"') and name.endswith('"'):
        name = name[1:-1]
    if len(name) >= 2 and name.startswith("'") and name.endswith("'"):
        name = name[1:-1]

    if is_entity:
        # Remove Chinese quotes
        name = name.replace("“", "").replace("”", "").replace("‘", "").replace("’", "")
        # Remove English quotes in and around Chinese characters
        name = re.sub(r"['\"]+(?=[\u4e00-\u9fa5])", "", name)
        name = re.sub(r"(?<=[\u4e00-\u9fa5])['\"]+", "", name)

    return name

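# Examples of the rules above:
#     normalize_extracted_info("人工 智能 (AI)")                 -> "人工智能(AI)"
#     normalize_extracted_info('"Alan Turing"', is_entity=True)  -> "Alan Turing"
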
def clean_text(text: str) -> str:
    """Clean text by stripping surrounding whitespace and removing null bytes (0x00).

    Args:
        text: Input text to clean

    Returns:
        Cleaned text
    """
    return text.strip().replace("\x00", "")

def check_storage_env_vars(storage_name: str) -> None:
    """Check if all required environment variables for a storage implementation exist.

    Args:
        storage_name: Storage implementation name

    Raises:
        ValueError: If required environment variables are missing
    """
    from lightrag.kg import STORAGE_ENV_REQUIREMENTS

    required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
    missing_vars = [var for var in required_vars if var not in os.environ]

    if missing_vars:
        raise ValueError(
            f"Storage implementation '{storage_name}' requires the following "
            f"environment variables: {', '.join(missing_vars)}"
        )

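# Usage sketch ("Neo4JStorage" is illustrative; valid names and their required
# variables are defined in lightrag.kg.STORAGE_ENV_REQUIREMENTS):
#
#     check_storage_env_vars("Neo4JStorage")  # raises ValueError if vars are missing
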
class TokenTracker:
    """Track token usage for LLM calls."""

    def __init__(self):
        self.reset()

    def __enter__(self):
        self.reset()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print(self)

    def reset(self):
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0
        self.call_count = 0

    def add_usage(self, token_counts):
        """Add token usage from one LLM call.

        Args:
            token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens
        """
        self.prompt_tokens += token_counts.get("prompt_tokens", 0)
        self.completion_tokens += token_counts.get("completion_tokens", 0)

        # If total_tokens is provided, use it directly; otherwise calculate the sum
        if "total_tokens" in token_counts:
            self.total_tokens += token_counts["total_tokens"]
        else:
            self.total_tokens += token_counts.get(
                "prompt_tokens", 0
            ) + token_counts.get("completion_tokens", 0)

        self.call_count += 1

    def get_usage(self):
        """Get current usage statistics."""
        return {
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "total_tokens": self.total_tokens,
            "call_count": self.call_count,
        }

    def __str__(self):
        usage = self.get_usage()
        return (
            f"LLM call count: {usage['call_count']}, "
            f"Prompt tokens: {usage['prompt_tokens']}, "
            f"Completion tokens: {usage['completion_tokens']}, "
            f"Total tokens: {usage['total_tokens']}"
        )
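# Usage sketch: aggregate usage across several calls; the dict shape matches
# what add_usage() reads.
#
#     tracker = TokenTracker()
#     with tracker:  # reset on enter, print the summary on exit
#         tracker.add_usage({"prompt_tokens": 120, "completion_tokens": 30})
#         tracker.add_usage({"prompt_tokens": 80, "completion_tokens": 25, "total_tokens": 105})
#     # -> LLM call count: 2, Prompt tokens: 200, Completion tokens: 55, Total tokens: 255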