from __future__ import annotations

import asyncio
import html
import io
import csv
import json
import logging
import logging.handlers
import os
import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Callable, TYPE_CHECKING
import xml.etree.ElementTree as ET

import numpy as np
import tiktoken

from lightrag.prompt import PROMPTS
from dotenv import load_dotenv

# Use TYPE_CHECKING to avoid circular imports
if TYPE_CHECKING:
    from lightrag.base import BaseKVStorage

# Use the .env file inside the current folder.
# This allows a different .env file for each LightRAG instance;
# OS environment variables take precedence over the .env file.
load_dotenv(dotenv_path=".env", override=False)

VERBOSE_DEBUG = os.getenv("VERBOSE", "false").lower() == "true"


def verbose_debug(msg: str, *args, **kwargs):
    """Function for outputting detailed debug information.

    When VERBOSE_DEBUG=True, outputs the complete message.
    When VERBOSE_DEBUG=False, outputs only the first 100 characters.

    Args:
        msg: The message format string
        *args: Arguments to be formatted into the message
        **kwargs: Keyword arguments passed to logger.debug()
    """
    if VERBOSE_DEBUG:
        logger.debug(msg, *args, **kwargs)
    else:
        # Format the message with args first
        if args:
            formatted_msg = msg % args
        else:
            formatted_msg = msg
        # Then truncate the formatted message
        truncated_msg = (
            formatted_msg[:100] + "..." if len(formatted_msg) > 100 else formatted_msg
        )
        logger.debug(truncated_msg, **kwargs)


def set_verbose_debug(enabled: bool):
    """Enable or disable verbose debug output"""
    global VERBOSE_DEBUG
    VERBOSE_DEBUG = enabled
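

# A minimal usage sketch for the verbose-debug switch (illustrative comments
# only; nothing here runs at import time):
#
#     set_verbose_debug(True)
#     verbose_debug("full payload: %s", long_json)   # logged in full
#     set_verbose_debug(False)
#     verbose_debug("full payload: %s", long_json)   # truncated to 100 chars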


statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}

# Initialize logger
logger = logging.getLogger("lightrag")
logger.propagate = False  # prevent log messages from reaching the root logger
# Let the main application configure the handlers
logger.setLevel(logging.INFO)

# Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING)


class LightragPathFilter(logging.Filter):
    """Filter for lightrag logger to filter out frequent path access logs"""

    def __init__(self):
        super().__init__()
        # Define paths to be filtered
        self.filtered_paths = [
            "/documents",
            "/health",
            "/webui/",
            "/documents/pipeline_status",
        ]
        # self.filtered_paths = ["/health", "/webui/"]

    def filter(self, record):
        try:
            # Check if record has the required attributes for an access log
            if not hasattr(record, "args") or not isinstance(record.args, tuple):
                return True
            if len(record.args) < 5:
                return True

            # Extract method, path and status from the record args
            method = record.args[1]
            path = record.args[2]
            status = record.args[4]

            # Filter out successful GET requests to filtered paths
            if (
                method == "GET"
                and (status == 200 or status == 304)
                and path in self.filtered_paths
            ):
                return False

            return True
        except Exception:
            # In case of any error, let the message through
            return True


def setup_logger(
    logger_name: str,
    level: str = "INFO",
    add_filter: bool = False,
    log_file_path: str | None = None,
    enable_file_logging: bool = True,
):
    """Set up a logger with console and optionally file handlers

    Args:
        logger_name: Name of the logger to set up
        level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
        add_filter: Whether to add LightragPathFilter to the logger
        log_file_path: Path to the log file. If None and file logging is enabled, defaults to lightrag.log in LOG_DIR or cwd
        enable_file_logging: Whether to enable logging to a file (defaults to True)
    """
    # Configure formatters
    detailed_formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    simple_formatter = logging.Formatter("%(levelname)s: %(message)s")

    logger_instance = logging.getLogger(logger_name)
    logger_instance.setLevel(level)
    logger_instance.handlers = []  # Clear existing handlers
    logger_instance.propagate = False

    # Add console handler
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(simple_formatter)
    console_handler.setLevel(level)
    logger_instance.addHandler(console_handler)

    # Add file handler by default unless explicitly disabled
    if enable_file_logging:
        # Get log file path
        if log_file_path is None:
            log_dir = os.getenv("LOG_DIR", os.getcwd())
            log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))

        # Ensure log directory exists
        os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

        # Get log file max size and backup count from environment variables
        log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
        log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups

        try:
            # Add file handler
            file_handler = logging.handlers.RotatingFileHandler(
                filename=log_file_path,
                maxBytes=log_max_bytes,
                backupCount=log_backup_count,
                encoding="utf-8",
            )
            file_handler.setFormatter(detailed_formatter)
            file_handler.setLevel(level)
            logger_instance.addHandler(file_handler)
        except PermissionError as e:
            logger.warning(f"Could not create log file at {log_file_path}: {str(e)}")
            logger.warning("Continuing with console logging only")

    # Add path filter if requested
    if add_filter:
        path_filter = LightragPathFilter()
        logger_instance.addFilter(path_filter)
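

# A minimal usage sketch for setup_logger (illustrative; "uvicorn.access" is
# just an example logger name, and the environment variables are optional):
#
#     os.environ.setdefault("LOG_DIR", "/tmp")         # where lightrag.log goes
#     setup_logger("lightrag", level="DEBUG")          # console + rotating file
#     setup_logger("uvicorn.access", add_filter=True)  # drop noisy /health hits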


class UnlimitedSemaphore:
    """A context manager that allows unlimited access."""

    async def __aenter__(self):
        pass

    async def __aexit__(self, exc_type, exc, tb):
        pass


ENCODER = None


@dataclass
class EmbeddingFunc:
    embedding_dim: int
    max_token_size: int
    func: callable
    # concurrent_limit: int = 16

    async def __call__(self, *args, **kwargs) -> np.ndarray:
        return await self.func(*args, **kwargs)


def locate_json_string_body_from_string(content: str) -> str | None:
    """Locate the JSON string body from a string"""
    try:
        maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
        if maybe_json_str is not None:
            maybe_json_str = maybe_json_str.group(0)
            maybe_json_str = maybe_json_str.replace("\\n", "")
            maybe_json_str = maybe_json_str.replace("\n", "")
            maybe_json_str = maybe_json_str.replace("'", '"')
            # json.loads(maybe_json_str) # don't check here, cannot validate schema after all
            return maybe_json_str
    except Exception:
        pass
    # try:
    #     content = (
    #         content.replace(kw_prompt[:-1], "")
    #         .replace("user", "")
    #         .replace("model", "")
    #         .strip()
    #     )
    #     maybe_json_str = "{" + content.split("{")[1].split("}")[0] + "}"
    #     json.loads(maybe_json_str)

    return None


def convert_response_to_json(response: str) -> dict[str, Any]:
    json_str = locate_json_string_body_from_string(response)
    assert json_str is not None, f"Unable to parse JSON from response: {response}"
    try:
        data = json.loads(json_str)
        return data
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse JSON: {json_str}")
        raise e from None
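

# A minimal sketch of extracting JSON from a noisy LLM reply (illustrative):
#
#     raw = "Sure! Here is the result: {'high_level_keywords': ['graph rag']}"
#     convert_response_to_json(raw)
#     # -> {"high_level_keywords": ["graph rag"]}
#
# Note that the single-to-double-quote rewrite above is a heuristic;
# apostrophes inside string values can still break json.loads.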


def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
    """Compute a hash for the given arguments.

    Args:
        *args: Arguments to hash
        cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')

    Returns:
        str: Hash string
    """
    import hashlib

    # Convert all arguments to strings and join them
    args_str = "".join([str(arg) for arg in args])
    if cache_type:
        args_str = f"{cache_type}:{args_str}"

    # Compute MD5 hash
    return hashlib.md5(args_str.encode()).hexdigest()


def compute_mdhash_id(content: str, prefix: str = "") -> str:
    """
    Compute a unique ID for a given content string.

    The ID is a combination of the given prefix and the MD5 hash of the content string.
    """
    return prefix + md5(content.encode()).hexdigest()
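

# Example of how IDs and cache keys are derived (the values shown are what
# hashlib.md5 actually produces for these inputs):
#
#     compute_mdhash_id("hello", prefix="ent-")
#     # -> "ent-5d41402abc4b2a76b9719d911017c592"
#     compute_args_hash("hello", cache_type="query")
#     # hashes the string "query:hello", so the same text under a different
#     # cache_type yields a different key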


def limit_async_func_call(max_size: int):
    """Add restriction of maximum concurrent async calls using asyncio.Semaphore"""

    def final_decro(func):
        sem = asyncio.Semaphore(max_size)

        @wraps(func)
        async def wait_func(*args, **kwargs):
            async with sem:
                result = await func(*args, **kwargs)
                return result

        return wait_func

    return final_decro


def wrap_embedding_func_with_attrs(**kwargs):
    """Wrap a function with attributes"""

    def final_decro(func) -> EmbeddingFunc:
        new_func = EmbeddingFunc(**kwargs, func=func)
        return new_func

    return final_decro
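

# A minimal sketch combining both decorators (illustrative; the embedding
# function and its dimensions are made up for the example):
#
#     @wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=8192)
#     async def my_embed(texts: list[str]) -> np.ndarray:
#         ...  # call an embedding backend here (hypothetical)
#
#     # my_embed is now an EmbeddingFunc that still awaits like the original
#     # coroutine but carries embedding_dim / max_token_size metadata. Its
#     # inner coroutine can be concurrency-limited the same way:
#     limited = limit_async_func_call(16)(my_embed.func)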


def load_json(file_name):
    if not os.path.exists(file_name):
        return None
    with open(file_name, encoding="utf-8") as f:
        return json.load(f)


def write_json(json_obj, file_name):
    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(json_obj, f, indent=2, ensure_ascii=False)


# NOTE: the tiktoken encoder is cached in the module-level ENCODER after the
# first call, so model_name is only honored on that first call.
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
    global ENCODER
    if ENCODER is None:
        ENCODER = tiktoken.encoding_for_model(model_name)
    tokens = ENCODER.encode(content)
    return tokens


def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
    global ENCODER
    if ENCODER is None:
        ENCODER = tiktoken.encoding_for_model(model_name)
    content = ENCODER.decode(tokens)
    return content


def pack_user_ass_to_openai_messages(*args: str):
    roles = ["user", "assistant"]
    return [
        {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
    ]


def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split a string by multiple markers"""
    if not markers:
        return [content]
    content = content if content is not None else ""
    results = re.split("|".join(re.escape(marker) for marker in markers), content)
    return [r.strip() for r in results if r.strip()]
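

# Example of splitting on several delimiters at once (the marker strings are
# shortened stand-ins for the tuple-style delimiters used in prompts):
#
#     split_string_by_multi_markers("a<|>b##c", ["<|>", "##"])
#     # -> ["a", "b", "c"]   (empty or whitespace-only pieces are dropped)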


# Refer to the utils functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
    """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
    # If we get non-string input, just give it back
    if not isinstance(input, str):
        return input

    result = html.unescape(input.strip())
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
    return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)


def is_float_regex(value: str) -> bool:
    return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))


def truncate_list_by_token_size(
    list_data: list[Any], key: Callable[[Any], str], max_token_size: int
) -> list[Any]:
    """Truncate a list of data by token size"""
    if max_token_size <= 0:
        return []
    tokens = 0
    for i, data in enumerate(list_data):
        tokens += len(encode_string_by_tiktoken(key(data)))
        if tokens > max_token_size:
            return list_data[:i]
    return list_data
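

# A hedged sketch of token-budget truncation (the chunk dicts are
# hypothetical; actual token counts depend on the tiktoken model):
#
#     chunks = [{"content": "first chunk ..."}, {"content": "second chunk ..."}]
#     kept = truncate_list_by_token_size(
#         chunks, key=lambda c: c["content"], max_token_size=4000
#     )
#     # kept is a prefix of chunks whose summed token count stays within 4000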


def list_of_list_to_csv(data: list[list[str]]) -> str:
    output = io.StringIO()
    writer = csv.writer(
        output,
        quoting=csv.QUOTE_ALL,  # Quote all fields
        escapechar="\\",  # Use backslash as escape character
        quotechar='"',  # Use double quotes
        lineterminator="\n",  # Explicit line terminator
    )
    writer.writerows(data)
    return output.getvalue()


def csv_string_to_list(csv_string: str) -> list[list[str]]:
    # Clean the string by removing NUL characters
    cleaned_string = csv_string.replace("\0", "")

    output = io.StringIO(cleaned_string)
    reader = csv.reader(
        output,
        quoting=csv.QUOTE_ALL,  # Match the writer configuration
        escapechar="\\",  # Use backslash as escape character
        quotechar='"',  # Use double quotes
    )

    try:
        return [row for row in reader]
    except csv.Error as e:
        raise ValueError(f"Failed to parse CSV string: {str(e)}")
    finally:
        output.close()
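

# Round-trip sketch for the two CSV helpers; the reader deliberately mirrors
# the writer's quoting/escaping configuration, so this is lossless for
# ordinary text:
#
#     rows = [["id", "entity", "description"], ["1", "Alice", "a person"]]
#     csv_string_to_list(list_of_list_to_csv(rows)) == rows   # True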


def save_data_to_file(data, file_name):
    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)


def xml_to_json(xml_file):
    try:
        tree = ET.parse(xml_file)
        root = tree.getroot()

        # Print the root element's tag and attributes to confirm the file has been correctly loaded
        print(f"Root element: {root.tag}")
        print(f"Root attributes: {root.attrib}")

        data = {"nodes": [], "edges": []}

        # Use namespace
        namespace = {"": "http://graphml.graphdrawing.org/xmlns"}

        for node in root.findall(".//node", namespace):
            node_data = {
                "id": node.get("id").strip('"'),
                "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
                if node.find("./data[@key='d0']", namespace) is not None
                else "",
                "description": node.find("./data[@key='d1']", namespace).text
                if node.find("./data[@key='d1']", namespace) is not None
                else "",
                "source_id": node.find("./data[@key='d2']", namespace).text
                if node.find("./data[@key='d2']", namespace) is not None
                else "",
            }
            data["nodes"].append(node_data)

        for edge in root.findall(".//edge", namespace):
            edge_data = {
                "source": edge.get("source").strip('"'),
                "target": edge.get("target").strip('"'),
                "weight": float(edge.find("./data[@key='d3']", namespace).text)
                if edge.find("./data[@key='d3']", namespace) is not None
                else 0.0,
                "description": edge.find("./data[@key='d4']", namespace).text
                if edge.find("./data[@key='d4']", namespace) is not None
                else "",
                "keywords": edge.find("./data[@key='d5']", namespace).text
                if edge.find("./data[@key='d5']", namespace) is not None
                else "",
                "source_id": edge.find("./data[@key='d6']", namespace).text
                if edge.find("./data[@key='d6']", namespace) is not None
                else "",
            }
            data["edges"].append(edge_data)

        # Print the number of nodes and edges found
        print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")

        return data
    except ET.ParseError as e:
        print(f"Error parsing XML file: {e}")
        return None
    except Exception as e:
        print(f"An error occurred: {e}")
        return None


def process_combine_contexts(hl: str, ll: str):
    """Combine high-level and low-level CSV contexts: drop the per-row index
    column, de-duplicate rows, and renumber the merged result."""
    header = None
    list_hl = csv_string_to_list(hl.strip())
    list_ll = csv_string_to_list(ll.strip())

    if list_hl:
        header = list_hl[0]
        list_hl = list_hl[1:]
    if list_ll:
        header = list_ll[0]
        list_ll = list_ll[1:]
    if header is None:
        return ""

    if list_hl:
        list_hl = [",".join(item[1:]) for item in list_hl if item]
    if list_ll:
        list_ll = [",".join(item[1:]) for item in list_ll if item]

    combined_sources = []
    seen = set()

    for item in list_hl + list_ll:
        if item and item not in seen:
            combined_sources.append(item)
            seen.add(item)

    combined_sources_result = [",\t".join(header)]

    for i, item in enumerate(combined_sources, start=1):
        combined_sources_result.append(f"{i},\t{item}")

    combined_sources_result = "\n".join(combined_sources_result)

    return combined_sources_result


async def get_best_cached_response(
    hashing_kv,
    current_embedding,
    similarity_threshold=0.95,
    mode="default",
    use_llm_check=False,
    llm_func=None,
    original_prompt=None,
    cache_type=None,
) -> str | None:
    logger.debug(
        f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
    )
    mode_cache = await hashing_kv.get_by_id(mode)
    if not mode_cache:
        return None

    best_similarity = -1
    best_response = None
    best_prompt = None
    best_cache_id = None

    # Only iterate through cache entries for this mode
    for cache_id, cache_data in mode_cache.items():
        # Skip if cache_type doesn't match
        if cache_type and cache_data.get("cache_type") != cache_type:
            continue

        if cache_data["embedding"] is None:
            continue

        # Convert cached embedding list to ndarray
        cached_quantized = np.frombuffer(
            bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
        ).reshape(cache_data["embedding_shape"])
        cached_embedding = dequantize_embedding(
            cached_quantized,
            cache_data["embedding_min"],
            cache_data["embedding_max"],
        )

        similarity = cosine_similarity(current_embedding, cached_embedding)
        if similarity > best_similarity:
            best_similarity = similarity
            best_response = cache_data["return"]
            best_prompt = cache_data["original_prompt"]
            best_cache_id = cache_id

    if best_similarity > similarity_threshold:
        # If LLM check is enabled and all required parameters are provided
        if (
            use_llm_check
            and llm_func
            and original_prompt
            and best_prompt
            and best_response is not None
        ):
            compare_prompt = PROMPTS["similarity_check"].format(
                original_prompt=original_prompt, cached_prompt=best_prompt
            )

            try:
                llm_result = await llm_func(compare_prompt)
                llm_result = llm_result.strip()
                llm_similarity = float(llm_result)

                # Replace vector similarity with LLM similarity score
                best_similarity = llm_similarity
                if best_similarity < similarity_threshold:
                    log_data = {
                        "event": "cache_rejected_by_llm",
                        "type": cache_type,
                        "mode": mode,
                        "original_question": original_prompt[:100] + "..."
                        if len(original_prompt) > 100
                        else original_prompt,
                        "cached_question": best_prompt[:100] + "..."
                        if len(best_prompt) > 100
                        else best_prompt,
                        "similarity_score": round(best_similarity, 4),
                        "threshold": similarity_threshold,
                    }
                    logger.debug(json.dumps(log_data, ensure_ascii=False))
                    logger.info(f"Cache rejected by LLM(mode:{mode} type:{cache_type})")
                    return None
            except Exception as e:  # Catch all possible exceptions
                logger.warning(f"LLM similarity check failed: {e}")
                return None  # Return None directly when LLM check fails

        prompt_display = (
            best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
        )
        log_data = {
            "event": "cache_hit",
            "type": cache_type,
            "mode": mode,
            "similarity": round(best_similarity, 4),
            "cache_id": best_cache_id,
            "original_prompt": prompt_display,
        }
        logger.debug(json.dumps(log_data, ensure_ascii=False))
        return best_response
    return None


def cosine_similarity(v1, v2):
    """Calculate cosine similarity between two vectors"""
    dot_product = np.dot(v1, v2)
    norm1 = np.linalg.norm(v1)
    norm2 = np.linalg.norm(v2)
    return dot_product / (norm1 * norm2)


def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
    """Quantize embedding to specified bits"""
    # Convert list to numpy array if needed
    if isinstance(embedding, list):
        embedding = np.array(embedding)

    # Calculate min/max values for reconstruction
    min_val = embedding.min()
    max_val = embedding.max()

    # Quantize to 0-255 range
    scale = (2**bits - 1) / (max_val - min_val)
    quantized = np.round((embedding - min_val) * scale).astype(np.uint8)

    return quantized, min_val, max_val


def dequantize_embedding(
    quantized: np.ndarray, min_val: float, max_val: float, bits=8
) -> np.ndarray:
    """Restore quantized embedding"""
    scale = (max_val - min_val) / (2**bits - 1)
    return (quantized * scale + min_val).astype(np.float32)
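

# Round-trip sketch for the 8-bit embedding quantization used by the cache.
# The maximum reconstruction error is half a quantization step, i.e.
# (max_val - min_val) / (2 * 255) for bits=8:
#
#     emb = np.random.rand(768).astype(np.float32)
#     q, lo, hi = quantize_embedding(emb)           # q is uint8, ~4x smaller
#     restored = dequantize_embedding(q, lo, hi)
#     assert np.abs(emb - restored).max() <= (hi - lo) / (2 * 255) + 1e-6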


async def handle_cache(
    hashing_kv,
    args_hash,
    prompt,
    mode="default",
    cache_type=None,
):
    """Generic cache handling function"""
    if hashing_kv is None:
        return None, None, None, None

    if mode != "default":  # handle cache for all types of query
        if not hashing_kv.global_config.get("enable_llm_cache"):
            return None, None, None, None

        # Get embedding cache configuration
        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config",
            {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
        )
        is_embedding_cache_enabled = embedding_cache_config["enabled"]
        use_llm_check = embedding_cache_config.get("use_llm_check", False)

        quantized = min_val = max_val = None
        if is_embedding_cache_enabled:  # Use embedding similarity to match cache
            current_embedding = await hashing_kv.embedding_func([prompt])
            llm_model_func = hashing_kv.global_config.get("llm_model_func")
            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
            best_cached_response = await get_best_cached_response(
                hashing_kv,
                current_embedding[0],
                similarity_threshold=embedding_cache_config["similarity_threshold"],
                mode=mode,
                use_llm_check=use_llm_check,
                llm_func=llm_model_func if use_llm_check else None,
                original_prompt=prompt,
                cache_type=cache_type,
            )
            if best_cached_response is not None:
                logger.debug(f"Embedding cache hit(mode:{mode} type:{cache_type})")
                return best_cached_response, None, None, None
            else:
                # if caching keyword embedding is enabled, return the quantized embedding for saving it later
                logger.debug(f"Embedding cache missed(mode:{mode} type:{cache_type})")
                return None, quantized, min_val, max_val

    else:  # handle cache for entity extraction
        if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
            return None, None, None, None

    # Here are the conditions under which code reaches this point:
    # 1. All query modes: enable_llm_cache is True and embedding similarity is not enabled
    # 2. Entity extract: enable_llm_cache_for_entity_extract is True
    if exists_func(hashing_kv, "get_by_mode_and_id"):
        mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
    else:
        mode_cache = await hashing_kv.get_by_id(mode) or {}
    if args_hash in mode_cache:
        logger.debug(f"Non-embedding cache hit(mode:{mode} type:{cache_type})")
        return mode_cache[args_hash]["return"], None, None, None

    logger.debug(f"Non-embedding cache missed(mode:{mode} type:{cache_type})")
    return None, None, None, None


@dataclass
class CacheData:
    args_hash: str
    content: str
    prompt: str
    quantized: np.ndarray | None = None
    min_val: float | None = None
    max_val: float | None = None
    mode: str = "default"
    cache_type: str = "query"


async def save_to_cache(hashing_kv, cache_data: CacheData):
    """Save data to cache, with improved handling for streaming responses and duplicate content.

    Args:
        hashing_kv: The key-value storage for caching
        cache_data: The cache data to save
    """
    # Skip if storage is None or content is empty
    if hashing_kv is None or not cache_data.content:
        return

    # If content is a streaming response, don't cache it
    if hasattr(cache_data.content, "__aiter__"):
        logger.debug("Streaming response detected, skipping cache")
        return

    # Get existing cache data
    if exists_func(hashing_kv, "get_by_mode_and_id"):
        mode_cache = (
            await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
            or {}
        )
    else:
        mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}

    # Check if we already have identical content cached
    if cache_data.args_hash in mode_cache:
        existing_content = mode_cache[cache_data.args_hash].get("return")
        if existing_content == cache_data.content:
            logger.info(
                f"Cache content unchanged for {cache_data.args_hash}, skipping update"
            )
            return

    # Update cache with new content
    mode_cache[cache_data.args_hash] = {
        "return": cache_data.content,
        "cache_type": cache_data.cache_type,
        "embedding": cache_data.quantized.tobytes().hex()
        if cache_data.quantized is not None
        else None,
        "embedding_shape": cache_data.quantized.shape
        if cache_data.quantized is not None
        else None,
        "embedding_min": cache_data.min_val,
        "embedding_max": cache_data.max_val,
        "original_prompt": cache_data.prompt,
    }

    logger.info(f" == LLM cache == saving {cache_data.mode}: {cache_data.args_hash}")

    # Only upsert if there's actual new content
    await hashing_kv.upsert({cache_data.mode: mode_cache})


def safe_unicode_decode(content):
    # Regular expression to find all Unicode escape sequences of the form \uXXXX
    unicode_escape_pattern = re.compile(r"\\u([0-9a-fA-F]{4})")

    # Function to replace the Unicode escape with the actual character
    def replace_unicode_escape(match):
        # Convert the matched hexadecimal value into the actual Unicode character
        return chr(int(match.group(1), 16))

    # Perform the substitution
    decoded_content = unicode_escape_pattern.sub(
        replace_unicode_escape, content.decode("utf-8")
    )

    return decoded_content


def exists_func(obj, func_name: str) -> bool:
    """Check if a callable attribute exists on an object.

    :param obj: Object to inspect
    :param func_name: Name of the attribute to look for
    :return: True / False
    """
    if callable(getattr(obj, func_name, None)):
        return True
    else:
        return False


def get_conversation_turns(
    conversation_history: list[dict[str, Any]], num_turns: int
) -> str:
    """
    Process conversation history to get the specified number of complete turns.

    Args:
        conversation_history: List of conversation messages in chronological order
        num_turns: Number of complete turns to include

    Returns:
        Formatted string of the conversation history
    """
    # Check if num_turns is valid
    if num_turns <= 0:
        return ""

    # Group messages into turns
    turns: list[list[dict[str, Any]]] = []
    messages: list[dict[str, Any]] = []

    # First, filter out keyword extraction messages
    for msg in conversation_history:
        if msg["role"] == "assistant" and (
            msg["content"].startswith('{ "high_level_keywords"')
            or msg["content"].startswith("{'high_level_keywords'")
        ):
            continue
        messages.append(msg)

    # Then process messages in chronological order
    i = 0
    while i < len(messages) - 1:
        msg1 = messages[i]
        msg2 = messages[i + 1]

        # Check if we have a user-assistant or assistant-user pair
        if (msg1["role"] == "user" and msg2["role"] == "assistant") or (
            msg1["role"] == "assistant" and msg2["role"] == "user"
        ):
            # Always put user message first in the turn
            if msg1["role"] == "assistant":
                turn = [msg2, msg1]  # user, assistant
            else:
                turn = [msg1, msg2]  # user, assistant
            turns.append(turn)
        i += 2

    # Keep only the most recent num_turns
    if len(turns) > num_turns:
        turns = turns[-num_turns:]

    # Format the turns into a string
    formatted_turns: list[str] = []
    for turn in turns:
        formatted_turns.extend(
            [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
        )

    return "\n".join(formatted_turns)
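

# Example of turn extraction (the history entries are hypothetical):
#
#     history = [
#         {"role": "user", "content": "hi"},
#         {"role": "assistant", "content": "hello"},
#         {"role": "user", "content": "what is RAG?"},
#         {"role": "assistant", "content": "retrieval-augmented generation"},
#     ]
#     get_conversation_turns(history, num_turns=1)
#     # -> "user: what is RAG?\nassistant: retrieval-augmented generation"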


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """
    Ensure that there is always an event loop available.

    This function tries to get the current event loop. If the current event loop is closed or does not exist,
    it creates a new event loop and sets it as the current event loop.

    Returns:
        asyncio.AbstractEventLoop: The current or newly created event loop.
    """
    try:
        # Try to get the current event loop
        current_loop = asyncio.get_event_loop()
        if current_loop.is_closed():
            raise RuntimeError("Event loop is closed.")
        return current_loop
    except RuntimeError:
        # If no event loop exists or it is closed, create a new one
        logger.info("Creating a new event loop in main thread.")
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        return new_loop


async def aexport_data(
    chunk_entity_relation_graph,
    entities_vdb,
    relationships_vdb,
    output_path: str,
    file_format: str = "csv",
    include_vector_data: bool = False,
) -> None:
    """
    Asynchronously exports all entities, relations, and relationships to various formats.

    Args:
        chunk_entity_relation_graph: Graph storage instance for entities and relations
        entities_vdb: Vector database storage for entities
        relationships_vdb: Vector database storage for relationships
        output_path: The path to the output file (including extension).
        file_format: Output format - "csv", "excel", "md", "txt".
            - csv: Comma-separated values file
            - excel: Microsoft Excel file with multiple sheets
            - md: Markdown tables
            - txt: Plain text formatted output
        include_vector_data: Whether to include data from the vector database.
    """
    # Collect data
    entities_data = []
    relations_data = []
    relationships_data = []

    # --- Entities ---
    all_entities = await chunk_entity_relation_graph.get_all_labels()
    for entity_name in all_entities:
        # Get entity information from graph
        node_data = await chunk_entity_relation_graph.get_node(entity_name)
        source_id = node_data.get("source_id") if node_data else None

        entity_info = {
            "graph_data": node_data,
            "source_id": source_id,
        }

        # Optional: Get vector database information
        if include_vector_data:
            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
            vector_data = await entities_vdb.get_by_id(entity_id)
            entity_info["vector_data"] = vector_data

        entity_row = {
            "entity_name": entity_name,
            "source_id": source_id,
            "graph_data": str(
                entity_info["graph_data"]
            ),  # Convert to string to ensure compatibility
        }
        if include_vector_data and "vector_data" in entity_info:
            entity_row["vector_data"] = str(entity_info["vector_data"])
        entities_data.append(entity_row)

    # --- Relations ---
    for src_entity in all_entities:
        for tgt_entity in all_entities:
            if src_entity == tgt_entity:
                continue

            edge_exists = await chunk_entity_relation_graph.has_edge(
                src_entity, tgt_entity
            )
            if edge_exists:
                # Get edge information from graph
                edge_data = await chunk_entity_relation_graph.get_edge(
                    src_entity, tgt_entity
                )
                source_id = edge_data.get("source_id") if edge_data else None

                relation_info = {
                    "graph_data": edge_data,
                    "source_id": source_id,
                }

                # Optional: Get vector database information
                if include_vector_data:
                    rel_id = compute_mdhash_id(src_entity + tgt_entity, prefix="rel-")
                    vector_data = await relationships_vdb.get_by_id(rel_id)
                    relation_info["vector_data"] = vector_data

                relation_row = {
                    "src_entity": src_entity,
                    "tgt_entity": tgt_entity,
                    "source_id": relation_info["source_id"],
                    "graph_data": str(relation_info["graph_data"]),  # Convert to string
                }
                if include_vector_data and "vector_data" in relation_info:
                    relation_row["vector_data"] = str(relation_info["vector_data"])
                relations_data.append(relation_row)

    # --- Relationships (from VectorDB) ---
    all_relationships = await relationships_vdb.client_storage
    for rel in all_relationships["data"]:
        relationships_data.append(
            {
                "relationship_id": rel["__id__"],
                "data": str(rel),  # Convert to string for compatibility
            }
        )

    # Export based on format
    if file_format == "csv":
        # CSV export
        with open(output_path, "w", newline="", encoding="utf-8") as csvfile:
            # Entities
            if entities_data:
                csvfile.write("# ENTITIES\n")
                writer = csv.DictWriter(csvfile, fieldnames=entities_data[0].keys())
                writer.writeheader()
                writer.writerows(entities_data)
                csvfile.write("\n\n")

            # Relations
            if relations_data:
                csvfile.write("# RELATIONS\n")
                writer = csv.DictWriter(csvfile, fieldnames=relations_data[0].keys())
                writer.writeheader()
                writer.writerows(relations_data)
                csvfile.write("\n\n")

            # Relationships
            if relationships_data:
                csvfile.write("# RELATIONSHIPS\n")
                writer = csv.DictWriter(
                    csvfile, fieldnames=relationships_data[0].keys()
                )
                writer.writeheader()
                writer.writerows(relationships_data)

    elif file_format == "excel":
        # Excel export
        import pandas as pd

        entities_df = pd.DataFrame(entities_data) if entities_data else pd.DataFrame()
        relations_df = (
            pd.DataFrame(relations_data) if relations_data else pd.DataFrame()
        )
        relationships_df = (
            pd.DataFrame(relationships_data) if relationships_data else pd.DataFrame()
        )

        with pd.ExcelWriter(output_path, engine="xlsxwriter") as writer:
            if not entities_df.empty:
                entities_df.to_excel(writer, sheet_name="Entities", index=False)
            if not relations_df.empty:
                relations_df.to_excel(writer, sheet_name="Relations", index=False)
            if not relationships_df.empty:
                relationships_df.to_excel(
                    writer, sheet_name="Relationships", index=False
                )

    elif file_format == "md":
        # Markdown export
        with open(output_path, "w", encoding="utf-8") as mdfile:
            mdfile.write("# LightRAG Data Export\n\n")

            # Entities
            mdfile.write("## Entities\n\n")
            if entities_data:
                # Write header
                mdfile.write("| " + " | ".join(entities_data[0].keys()) + " |\n")
                mdfile.write(
                    "| " + " | ".join(["---"] * len(entities_data[0].keys())) + " |\n"
                )

                # Write rows
                for entity in entities_data:
                    mdfile.write(
                        "| " + " | ".join(str(v) for v in entity.values()) + " |\n"
                    )
                mdfile.write("\n\n")
            else:
                mdfile.write("*No entity data available*\n\n")

            # Relations
            mdfile.write("## Relations\n\n")
            if relations_data:
                # Write header
                mdfile.write("| " + " | ".join(relations_data[0].keys()) + " |\n")
                mdfile.write(
                    "| " + " | ".join(["---"] * len(relations_data[0].keys())) + " |\n"
                )

                # Write rows
                for relation in relations_data:
                    mdfile.write(
                        "| " + " | ".join(str(v) for v in relation.values()) + " |\n"
                    )
                mdfile.write("\n\n")
            else:
                mdfile.write("*No relation data available*\n\n")

            # Relationships
            mdfile.write("## Relationships\n\n")
            if relationships_data:
                # Write header
                mdfile.write("| " + " | ".join(relationships_data[0].keys()) + " |\n")
                mdfile.write(
                    "| "
                    + " | ".join(["---"] * len(relationships_data[0].keys()))
                    + " |\n"
                )

                # Write rows
                for relationship in relationships_data:
                    mdfile.write(
                        "| "
                        + " | ".join(str(v) for v in relationship.values())
                        + " |\n"
                    )
            else:
                mdfile.write("*No relationship data available*\n\n")

    elif file_format == "txt":
        # Plain text export
        with open(output_path, "w", encoding="utf-8") as txtfile:
            txtfile.write("LIGHTRAG DATA EXPORT\n")
            txtfile.write("=" * 80 + "\n\n")

            # Entities
            txtfile.write("ENTITIES\n")
            txtfile.write("-" * 80 + "\n")
            if entities_data:
                # Create fixed width columns
                col_widths = {
                    k: max(len(k), max(len(str(e[k])) for e in entities_data))
                    for k in entities_data[0]
                }
                header = " ".join(k.ljust(col_widths[k]) for k in entities_data[0])
                txtfile.write(header + "\n")
                txtfile.write("-" * len(header) + "\n")

                # Write rows
                for entity in entities_data:
                    row = " ".join(
                        str(v).ljust(col_widths[k]) for k, v in entity.items()
                    )
                    txtfile.write(row + "\n")
                txtfile.write("\n\n")
            else:
                txtfile.write("No entity data available\n\n")

            # Relations
            txtfile.write("RELATIONS\n")
            txtfile.write("-" * 80 + "\n")
            if relations_data:
                # Create fixed width columns
                col_widths = {
                    k: max(len(k), max(len(str(r[k])) for r in relations_data))
                    for k in relations_data[0]
                }
                header = " ".join(k.ljust(col_widths[k]) for k in relations_data[0])
                txtfile.write(header + "\n")
                txtfile.write("-" * len(header) + "\n")

                # Write rows
                for relation in relations_data:
                    row = " ".join(
                        str(v).ljust(col_widths[k]) for k, v in relation.items()
                    )
                    txtfile.write(row + "\n")
                txtfile.write("\n\n")
            else:
                txtfile.write("No relation data available\n\n")

            # Relationships
            txtfile.write("RELATIONSHIPS\n")
            txtfile.write("-" * 80 + "\n")
            if relationships_data:
                # Create fixed width columns
                col_widths = {
                    k: max(len(k), max(len(str(r[k])) for r in relationships_data))
                    for k in relationships_data[0]
                }
                header = " ".join(
                    k.ljust(col_widths[k]) for k in relationships_data[0]
                )
                txtfile.write(header + "\n")
                txtfile.write("-" * len(header) + "\n")

                # Write rows
                for relationship in relationships_data:
                    row = " ".join(
                        str(v).ljust(col_widths[k]) for k, v in relationship.items()
                    )
                    txtfile.write(row + "\n")
            else:
                txtfile.write("No relationship data available\n\n")

    else:
        raise ValueError(
            f"Unsupported file format: {file_format}. "
            "Choose from: csv, excel, md, txt"
        )
    if file_format is not None:
        print(f"Data exported to: {output_path} with format: {file_format}")
    else:
        print("Data displayed as table format")


def export_data(
    chunk_entity_relation_graph,
    entities_vdb,
    relationships_vdb,
    output_path: str,
    file_format: str = "csv",
    include_vector_data: bool = False,
) -> None:
    """
    Synchronously exports all entities, relations, and relationships to various formats.

    Args:
        chunk_entity_relation_graph: Graph storage instance for entities and relations
        entities_vdb: Vector database storage for entities
        relationships_vdb: Vector database storage for relationships
        output_path: The path to the output file (including extension).
        file_format: Output format - "csv", "excel", "md", "txt".
            - csv: Comma-separated values file
            - excel: Microsoft Excel file with multiple sheets
            - md: Markdown tables
            - txt: Plain text formatted output
        include_vector_data: Whether to include data from the vector database.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    loop.run_until_complete(
        aexport_data(
            chunk_entity_relation_graph,
            entities_vdb,
            relationships_vdb,
            output_path,
            file_format,
            include_vector_data,
        )
    )
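

# A minimal usage sketch for the synchronous exporter (the storage objects
# come from an initialized LightRAG instance; the "rag." attribute names are
# assumptions based on this module's parameter names and may differ):
#
#     export_data(
#         rag.chunk_entity_relation_graph,
#         rag.entities_vdb,
#         rag.relationships_vdb,
#         "graph_export.md",
#         file_format="md",
#     )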


def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
    """Lazily import a class from an external module based on the package of the caller."""
    # Get the caller's module and package
    import inspect

    caller_frame = inspect.currentframe().f_back
    module = inspect.getmodule(caller_frame)
    package = module.__package__ if module else None

    def import_class(*args: Any, **kwargs: Any):
        import importlib

        module = importlib.import_module(module_name, package=package)
        cls = getattr(module, class_name)
        return cls(*args, **kwargs)

    return import_class
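

# Hedged example of the lazy-import pattern (the module path and class name
# here are purely illustrative):
#
#     NetworkXStorage = lazy_external_import(".kg.networkx_impl", "NetworkXStorage")
#     storage = NetworkXStorage(namespace="chunks")  # import happens only now
#
# Deferring importlib.import_module until first call keeps optional storage
# backends (and their dependencies) from being imported at package load time.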


async def use_llm_func_with_cache(
    input_text: str,
    use_llm_func: callable,
    llm_response_cache: "BaseKVStorage | None" = None,
    max_tokens: int = None,
    history_messages: list[dict[str, str]] = None,
    cache_type: str = "extract",
) -> str:
    """Call LLM function with cache support

    If cache is available and enabled (determined by handle_cache based on mode),
    retrieve result from cache; otherwise call LLM function and save result to cache.

    Args:
        input_text: Input text to send to LLM
        use_llm_func: LLM function to call
        llm_response_cache: Cache storage instance
        max_tokens: Maximum tokens for generation
        history_messages: History messages list
        cache_type: Type of cache

    Returns:
        LLM response text
    """
    if llm_response_cache:
        if history_messages:
            history = json.dumps(history_messages, ensure_ascii=False)
            _prompt = history + "\n" + input_text
        else:
            _prompt = input_text

        arg_hash = compute_args_hash(_prompt)
        cached_return, _1, _2, _3 = await handle_cache(
            llm_response_cache,
            arg_hash,
            _prompt,
            "default",
            cache_type=cache_type,
        )
        if cached_return:
            logger.debug(f"Found cache for {arg_hash}")
            statistic_data["llm_cache"] += 1
            return cached_return
        statistic_data["llm_call"] += 1

        # Call LLM
        kwargs = {}
        if history_messages:
            kwargs["history_messages"] = history_messages
        if max_tokens is not None:
            kwargs["max_tokens"] = max_tokens

        res: str = await use_llm_func(input_text, **kwargs)

        if llm_response_cache.global_config.get("enable_llm_cache_for_entity_extract"):
            await save_to_cache(
                llm_response_cache,
                CacheData(
                    args_hash=arg_hash,
                    content=res,
                    prompt=_prompt,
                    cache_type=cache_type,
                ),
            )

        return res

    # When cache is disabled, directly call LLM
    kwargs = {}
    if history_messages:
        kwargs["history_messages"] = history_messages
    if max_tokens is not None:
        kwargs["max_tokens"] = max_tokens

    logger.info(f"Call LLM function with query text length: {len(input_text)}")
    return await use_llm_func(input_text, **kwargs)
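
# Usage sketch (illustrative, not executed). `my_model_func` stands in for any
# async LLM callable with the (prompt, **kwargs) shape this helper expects; the
# cache instance would normally come from a LightRAG object, and the attribute
# name below is an assumption.
#
#   async def my_model_func(prompt: str, **kwargs) -> str:
#       ...  # call your model here
#
#   result = await use_llm_func_with_cache(
#       "Extract entities from: ...",
#       my_model_func,
#       llm_response_cache=rag.llm_response_cache,  # assumed attribute name
#       max_tokens=1024,
#       cache_type="extract",
#   )

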
def get_content_summary(content: str, max_length: int = 250) -> str:
    """Get summary of document content

    Args:
        content: Original document content
        max_length: Maximum length of summary

    Returns:
        Truncated content with ellipsis if needed
    """
    content = content.strip()
    if len(content) <= max_length:
        return content
    return content[:max_length] + "..."
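
# Illustrative behavior (not executed):
#
#   get_content_summary("short text")  -> "short text"
#   get_content_summary("x" * 300)     -> first 250 chars + "..."

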
def normalize_extracted_info(name: str, is_entity=False) -> str:
    """Normalize entity/relation names and descriptions with the following rules:
    1. Remove spaces between Chinese characters
    2. Remove spaces between Chinese characters and English letters/numbers
    3. Preserve spaces within English text and numbers
    4. Replace Chinese parentheses with English parentheses
    5. Replace Chinese dash with English dash

    Args:
        name: Entity name to normalize
        is_entity: Whether the value is an entity name (applies extra quote stripping)

    Returns:
        Normalized entity name
    """
    # Replace Chinese parentheses with English parentheses
    name = name.replace("（", "(").replace("）", ")")

    # Replace Chinese dash with English dash
    name = name.replace("—", "-").replace("－", "-")

    # Use regex to remove spaces between Chinese characters
    # Regex explanation:
    # (?<=[\u4e00-\u9fa5]): Positive lookbehind for a Chinese character
    # \s+: One or more whitespace characters
    # (?=[\u4e00-\u9fa5]): Positive lookahead for a Chinese character
    name = re.sub(r"(?<=[\u4e00-\u9fa5])\s+(?=[\u4e00-\u9fa5])", "", name)

    # Remove spaces between Chinese and English/numbers
    name = re.sub(r"(?<=[\u4e00-\u9fa5])\s+(?=[a-zA-Z0-9])", "", name)
    name = re.sub(r"(?<=[a-zA-Z0-9])\s+(?=[\u4e00-\u9fa5])", "", name)

    # Remove English quotation marks from the beginning and end
    if len(name) >= 2 and name.startswith('"') and name.endswith('"'):
        name = name[1:-1]

    if is_entity:
        # Remove Chinese quotes
        name = name.replace("“", "").replace("”", "").replace("‘", "").replace("’", "")
        # Remove English quotes in and around Chinese text
        name = re.sub(r"['\"]+(?=[\u4e00-\u9fa5])", "", name)
        name = re.sub(r"(?<=[\u4e00-\u9fa5])['\"]+", "", name)

    return name
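
# Illustrative behavior (not executed); these inputs/outputs follow directly
# from the rules above:
#
#   normalize_extracted_info("深度 学习 模型")          -> "深度学习模型"
#   normalize_extracted_info("GPT 模型")                -> "GPT模型"
#   normalize_extracted_info("（注）机器—学习")          -> "(注)机器-学习"
#   normalize_extracted_info('"张三"', is_entity=True)  -> "张三"

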
def clean_text(text: str) -> str:
    """Clean text by removing null bytes (0x00) and surrounding whitespace

    Args:
        text: Input text to clean

    Returns:
        Cleaned text
    """
    return text.strip().replace("\x00", "")


def check_storage_env_vars(storage_name: str) -> None:
    """Check if all required environment variables for storage implementation exist

    Args:
        storage_name: Storage implementation name

    Raises:
        ValueError: If required environment variables are missing
    """
    from lightrag.kg import STORAGE_ENV_REQUIREMENTS

    required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
    missing_vars = [var for var in required_vars if var not in os.environ]

    if missing_vars:
        raise ValueError(
            f"Storage implementation '{storage_name}' requires the following "
            f"environment variables: {', '.join(missing_vars)}"
        )
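
# Usage sketch (illustrative, not executed). The required variables depend on
# STORAGE_ENV_REQUIREMENTS; "Neo4JStorage" and NEO4J_URI are assumed examples,
# not guaranteed keys.
#
#   os.environ.setdefault("NEO4J_URI", "bolt://localhost:7687")
#   check_storage_env_vars("Neo4JStorage")  # raises ValueError if vars are missing

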
class TokenTracker:
    """Track token usage for LLM calls."""

    def __init__(self):
        self.reset()

    def __enter__(self):
        self.reset()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print(self)

    def reset(self):
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0
        self.call_count = 0

    def add_usage(self, token_counts):
        """Add token usage from one LLM call.

        Args:
            token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens
        """
        self.prompt_tokens += token_counts.get("prompt_tokens", 0)
        self.completion_tokens += token_counts.get("completion_tokens", 0)

        # If total_tokens is provided, use it directly; otherwise calculate the sum
        if "total_tokens" in token_counts:
            self.total_tokens += token_counts["total_tokens"]
        else:
            self.total_tokens += token_counts.get(
                "prompt_tokens", 0
            ) + token_counts.get("completion_tokens", 0)

        self.call_count += 1

    def get_usage(self):
        """Get current usage statistics."""
        return {
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "total_tokens": self.total_tokens,
            "call_count": self.call_count,
        }

    def __str__(self):
        usage = self.get_usage()
        return (
            f"LLM call count: {usage['call_count']}, "
            f"Prompt tokens: {usage['prompt_tokens']}, "
            f"Completion tokens: {usage['completion_tokens']}, "
            f"Total tokens: {usage['total_tokens']}"
        )
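

# Usage sketch (illustrative, not executed). The tracker is fed per-call token
# counts; how those counts reach `add_usage` (e.g. a `token_tracker` kwarg on
# the model function) depends on the caller and is assumed here.
#
#   tracker = TokenTracker()
#   with tracker:
#       tracker.add_usage({"prompt_tokens": 120, "completion_tokens": 48})
#       tracker.add_usage({"prompt_tokens": 80})  # total falls back to 80 + 0
#   # __exit__ prints: "LLM call count: 2, Prompt tokens: 200,
#   # Completion tokens: 48, Total tokens: 248"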