from __future__ import annotations

import asyncio
import html
import io
import csv
import json
import logging
import os
import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Callable
import xml.etree.ElementTree as ET

import bs4

import numpy as np
import tiktoken

from lightrag.prompt import PROMPTS


class UnlimitedSemaphore:
    """A context manager that allows unlimited access."""

    async def __aenter__(self):
        pass

    async def __aexit__(self, exc_type, exc, tb):
        pass


ENCODER = None

statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}

logger = logging.getLogger("lightrag")

# Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING)


def set_logger(log_file: str):
    logger.setLevel(logging.DEBUG)

    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setLevel(logging.DEBUG)

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler.setFormatter(formatter)

    if not logger.handlers:
        logger.addHandler(file_handler)


@dataclass
class EmbeddingFunc:
    embedding_dim: int
    max_token_size: int
    func: callable
    # concurrent_limit: int = 16

    async def __call__(self, *args, **kwargs) -> np.ndarray:
        return await self.func(*args, **kwargs)


@dataclass
class ReasoningResponse:
    reasoning_content: str | None
    response_content: str
    tag: str


def locate_json_string_body_from_string(content: str) -> str | None:
    """Locate the JSON string body from a string"""
    try:
        maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
        if maybe_json_str is not None:
            maybe_json_str = maybe_json_str.group(0)
            maybe_json_str = maybe_json_str.replace("\\n", "")
            maybe_json_str = maybe_json_str.replace("\n", "")
            maybe_json_str = maybe_json_str.replace("'", '"')
            # json.loads(maybe_json_str) # don't check here, cannot validate schema after all
            return maybe_json_str
    except Exception:
        pass
    # try:
    #     content = (
    #         content.replace(kw_prompt[:-1], "")
    #         .replace("user", "")
    #         .replace("model", "")
    #         .strip()
    #     )
    #     maybe_json_str = "{" + content.split("{")[1].split("}")[0] + "}"
    #     json.loads(maybe_json_str)

    return None


def convert_response_to_json(response: str) -> dict[str, Any]:
    json_str = locate_json_string_body_from_string(response)
    assert json_str is not None, f"Unable to parse JSON from response: {response}"
    try:
        data = json.loads(json_str)
        return data
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse JSON: {json_str}")
        raise e from None


def compute_args_hash(*args: Any, cache_type: str | None = None) -> str:
    """Compute a hash for the given arguments.

    Args:
        *args: Arguments to hash
        cache_type: Type of cache (e.g., 'keywords', 'query', 'extract')

    Returns:
        str: Hash string
    """
    import hashlib

    # Convert all arguments to strings and join them
    args_str = "".join([str(arg) for arg in args])
    if cache_type:
        args_str = f"{cache_type}:{args_str}"

    # Compute MD5 hash
    return hashlib.md5(args_str.encode()).hexdigest()
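
# Illustrative example (not part of the library API): hashing with a cache_type prefix
# yields a different key than hashing without one, e.g.
#   compute_args_hash("mix", "what is lightrag?", cache_type="query")
# is the MD5 hex digest of the string "query:mixwhat is lightrag?".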


def compute_mdhash_id(content: str, prefix: str = "") -> str:
    """
    Compute a unique ID for a given content string.

    The ID is a combination of the given prefix and the MD5 hash of the content string.
    """
    return prefix + md5(content.encode()).hexdigest()
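
# Illustrative example (the prefix values and chunk_text name are hypothetical here):
# callers typically tag the digest with a short prefix so IDs stay distinguishable, e.g.
#   compute_mdhash_id("APPLE", prefix="ent-")      -> "ent-" + 32 hex chars
#   compute_mdhash_id(chunk_text, prefix="chunk-") -> "chunk-" + 32 hex chars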


def limit_async_func_call(max_size: int):
    """Add restriction of maximum concurrent async calls using asyncio.Semaphore"""

    def final_decro(func):
        sem = asyncio.Semaphore(max_size)

        @wraps(func)
        async def wait_func(*args, **kwargs):
            async with sem:
                result = await func(*args, **kwargs)
                return result

        return wait_func

    return final_decro
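
# Usage sketch (illustrative; the decorated function name is hypothetical): allow at
# most 16 concurrent in-flight calls, queuing the rest on the semaphore.
#
#   @limit_async_func_call(16)
#   async def embed_texts(texts: list[str]) -> np.ndarray:
#       ...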


def wrap_embedding_func_with_attrs(**kwargs):
    """Wrap a function with attributes"""

    def final_decro(func) -> EmbeddingFunc:
        new_func = EmbeddingFunc(**kwargs, func=func)
        return new_func

    return final_decro
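
# Usage sketch (illustrative; the dimension and token values are hypothetical): turn an
# async embedding function into an EmbeddingFunc that carries its own metadata, so
# callers can read .embedding_dim and .max_token_size from the wrapped object.
#
#   @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
#   async def my_embedding_func(texts: list[str]) -> np.ndarray:
#       ...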


def load_json(file_name):
    if not os.path.exists(file_name):
        return None
    with open(file_name, encoding="utf-8") as f:
        return json.load(f)


def write_json(json_obj, file_name):
    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(json_obj, f, indent=2, ensure_ascii=False)


def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
    global ENCODER
    if ENCODER is None:
        ENCODER = tiktoken.encoding_for_model(model_name)
    tokens = ENCODER.encode(content)
    return tokens


def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
    global ENCODER
    if ENCODER is None:
        ENCODER = tiktoken.encoding_for_model(model_name)
    content = ENCODER.decode(tokens)
    return content
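
# Round-trip sketch (illustrative): both helpers share the module-level ENCODER, so
#   decode_tokens_by_tiktoken(encode_string_by_tiktoken("hello world"))
# returns "hello world". Note the encoder is built once for the first model_name seen
# and reused for every later call.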


def pack_user_ass_to_openai_messages(*args: str):
    roles = ["user", "assistant"]
    return [
        {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
    ]
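
# Illustrative example: pack_user_ass_to_openai_messages("Hi", "Hello!", "How are you?")
# returns alternating user/assistant messages:
#   [{"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello!"},
#    {"role": "user", "content": "How are you?"}]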


def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split a string by multiple markers"""
    if not markers:
        return [content]
    results = re.split("|".join(re.escape(marker) for marker in markers), content)
    return [r.strip() for r in results if r.strip()]
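
# Illustrative example: split_string_by_multi_markers("a<SEP>b##c", ["<SEP>", "##"])
# returns ["a", "b", "c"]; empty or whitespace-only segments are dropped.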


# Refer to the utils functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
    """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
    # If we get non-string input, just give it back
    if not isinstance(input, str):
        return input

    result = html.unescape(input.strip())
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
    return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)


def is_float_regex(value: str) -> bool:
    return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))


def truncate_list_by_token_size(
    list_data: list[Any], key: Callable[[Any], str], max_token_size: int
) -> list[Any]:
    """Truncate a list of data by token size"""
    if max_token_size <= 0:
        return []
    tokens = 0
    for i, data in enumerate(list_data):
        tokens += len(encode_string_by_tiktoken(key(data)))
        if tokens > max_token_size:
            return list_data[:i]
    return list_data
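
# Illustrative example (the record shape is hypothetical): keep only as many records as
# fit in the token budget, measured on each record's "content" field.
#   truncate_list_by_token_size(docs, key=lambda d: d["content"], max_token_size=4000)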


def list_of_list_to_csv(data: list[list[str]]) -> str:
    output = io.StringIO()
    writer = csv.writer(
        output,
        quoting=csv.QUOTE_ALL,  # Quote all fields
        escapechar="\\",  # Use backslash as escape character
        quotechar='"',  # Use double quotes
        lineterminator="\n",  # Explicit line terminator
    )
    writer.writerows(data)
    return output.getvalue()


def csv_string_to_list(csv_string: str) -> list[list[str]]:
    # Clean the string by removing NUL characters
    cleaned_string = csv_string.replace("\0", "")

    output = io.StringIO(cleaned_string)
    reader = csv.reader(
        output,
        quoting=csv.QUOTE_ALL,  # Match the writer configuration
        escapechar="\\",  # Use backslash as escape character
        quotechar='"',  # Use double quotes
    )

    try:
        return [row for row in reader]
    except csv.Error as e:
        raise ValueError(f"Failed to parse CSV string: {str(e)}")
    finally:
        output.close()
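
# Round-trip sketch (illustrative): the reader options mirror the writer's, so
#   csv_string_to_list(list_of_list_to_csv([["id", "entity"], ["1", "APPLE"]]))
# returns [["id", "entity"], ["1", "APPLE"]].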


def save_data_to_file(data, file_name):
    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)


def xml_to_json(xml_file):
    try:
        tree = ET.parse(xml_file)
        root = tree.getroot()

        # Print the root element's tag and attributes to confirm the file has been correctly loaded
        print(f"Root element: {root.tag}")
        print(f"Root attributes: {root.attrib}")

        data = {"nodes": [], "edges": []}

        # Use namespace
        namespace = {"": "http://graphml.graphdrawing.org/xmlns"}

        for node in root.findall(".//node", namespace):
            node_data = {
                "id": node.get("id").strip('"'),
                "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
                if node.find("./data[@key='d0']", namespace) is not None
                else "",
                "description": node.find("./data[@key='d1']", namespace).text
                if node.find("./data[@key='d1']", namespace) is not None
                else "",
                "source_id": node.find("./data[@key='d2']", namespace).text
                if node.find("./data[@key='d2']", namespace) is not None
                else "",
            }
            data["nodes"].append(node_data)

        for edge in root.findall(".//edge", namespace):
            edge_data = {
                "source": edge.get("source").strip('"'),
                "target": edge.get("target").strip('"'),
                "weight": float(edge.find("./data[@key='d3']", namespace).text)
                if edge.find("./data[@key='d3']", namespace) is not None
                else 0.0,
                "description": edge.find("./data[@key='d4']", namespace).text
                if edge.find("./data[@key='d4']", namespace) is not None
                else "",
                "keywords": edge.find("./data[@key='d5']", namespace).text
                if edge.find("./data[@key='d5']", namespace) is not None
                else "",
                "source_id": edge.find("./data[@key='d6']", namespace).text
                if edge.find("./data[@key='d6']", namespace) is not None
                else "",
            }
            data["edges"].append(edge_data)

        # Print the number of nodes and edges found
        print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")

        return data
    except ET.ParseError as e:
        print(f"Error parsing XML file: {e}")
        return None
    except Exception as e:
        print(f"An error occurred: {e}")
        return None


def process_combine_contexts(hl: str, ll: str):
    header = None
    list_hl = csv_string_to_list(hl.strip())
    list_ll = csv_string_to_list(ll.strip())

    if list_hl:
        header = list_hl[0]
        list_hl = list_hl[1:]
    if list_ll:
        header = list_ll[0]
        list_ll = list_ll[1:]
    if header is None:
        return ""

    if list_hl:
        list_hl = [",".join(item[1:]) for item in list_hl if item]
    if list_ll:
        list_ll = [",".join(item[1:]) for item in list_ll if item]

    combined_sources = []
    seen = set()

    for item in list_hl + list_ll:
        if item and item not in seen:
            combined_sources.append(item)
            seen.add(item)

    combined_sources_result = [",\t".join(header)]

    for i, item in enumerate(combined_sources, start=1):
        combined_sources_result.append(f"{i},\t{item}")

    combined_sources_result = "\n".join(combined_sources_result)

    return combined_sources_result


async def get_best_cached_response(
    hashing_kv,
    current_embedding,
    similarity_threshold=0.95,
    mode="default",
    use_llm_check=False,
    llm_func=None,
    original_prompt=None,
    cache_type=None,
) -> str | None:
    logger.debug(
        f"get_best_cached_response: mode={mode} cache_type={cache_type} use_llm_check={use_llm_check}"
    )
    mode_cache = await hashing_kv.get_by_id(mode)
    if not mode_cache:
        return None

    best_similarity = -1
    best_response = None
    best_prompt = None
    best_cache_id = None

    # Only iterate through cache entries for this mode
    for cache_id, cache_data in mode_cache.items():
        # Skip if cache_type doesn't match
        if cache_type and cache_data.get("cache_type") != cache_type:
            continue

        if cache_data["embedding"] is None:
            continue

        # Convert cached embedding list to ndarray
        cached_quantized = np.frombuffer(
            bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
        ).reshape(cache_data["embedding_shape"])
        cached_embedding = dequantize_embedding(
            cached_quantized,
            cache_data["embedding_min"],
            cache_data["embedding_max"],
        )

        similarity = cosine_similarity(current_embedding, cached_embedding)
        if similarity > best_similarity:
            best_similarity = similarity
            best_response = cache_data["return"]
            best_prompt = cache_data["original_prompt"]
            best_cache_id = cache_id

    if best_similarity > similarity_threshold:
        # If LLM check is enabled and all required parameters are provided
        if (
            use_llm_check
            and llm_func
            and original_prompt
            and best_prompt
            and best_response is not None
        ):
            compare_prompt = PROMPTS["similarity_check"].format(
                original_prompt=original_prompt, cached_prompt=best_prompt
            )

            try:
                llm_result = await llm_func(compare_prompt)
                llm_result = llm_result.strip()
                llm_similarity = float(llm_result)

                # Replace vector similarity with LLM similarity score
                best_similarity = llm_similarity
                if best_similarity < similarity_threshold:
                    log_data = {
                        "event": "cache_rejected_by_llm",
                        "type": cache_type,
                        "mode": mode,
                        "original_question": original_prompt[:100] + "..."
                        if len(original_prompt) > 100
                        else original_prompt,
                        "cached_question": best_prompt[:100] + "..."
                        if len(best_prompt) > 100
                        else best_prompt,
                        "similarity_score": round(best_similarity, 4),
                        "threshold": similarity_threshold,
                    }
                    logger.debug(json.dumps(log_data, ensure_ascii=False))
                    logger.info(f"Cache rejected by LLM(mode:{mode} type:{cache_type})")
                    return None
            except Exception as e:  # Catch all possible exceptions
                logger.warning(f"LLM similarity check failed: {e}")
                return None  # Return None directly when LLM check fails

        prompt_display = (
            best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
        )
        log_data = {
            "event": "cache_hit",
            "type": cache_type,
            "mode": mode,
            "similarity": round(best_similarity, 4),
            "cache_id": best_cache_id,
            "original_prompt": prompt_display,
        }
        logger.debug(json.dumps(log_data, ensure_ascii=False))
        return best_response
    return None


def cosine_similarity(v1, v2):
    """Calculate cosine similarity between two vectors"""
    dot_product = np.dot(v1, v2)
    norm1 = np.linalg.norm(v1)
    norm2 = np.linalg.norm(v2)
    return dot_product / (norm1 * norm2)


def quantize_embedding(embedding: np.ndarray | list[float], bits: int = 8) -> tuple:
    """Quantize embedding to specified bits"""
    # Convert list to numpy array if needed
    if isinstance(embedding, list):
        embedding = np.array(embedding)

    # Calculate min/max values for reconstruction
    min_val = embedding.min()
    max_val = embedding.max()

    # Quantize to 0-255 range
    scale = (2**bits - 1) / (max_val - min_val)
    quantized = np.round((embedding - min_val) * scale).astype(np.uint8)

    return quantized, min_val, max_val


def dequantize_embedding(
    quantized: np.ndarray, min_val: float, max_val: float, bits=8
) -> np.ndarray:
    """Restore quantized embedding"""
    scale = (max_val - min_val) / (2**bits - 1)
    return (quantized * scale + min_val).astype(np.float32)
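
# Round-trip sketch (illustrative): quantization is lossy but approximately invertible,
# which is how cached embeddings are stored compactly and later compared by cosine
# similarity.
#   q, lo, hi = quantize_embedding(np.array([0.1, 0.5, 0.9], dtype=np.float32))
#   dequantize_embedding(q, lo, hi)  # approx. array([0.1, 0.5, 0.9], dtype=float32)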


async def handle_cache(
    hashing_kv,
    args_hash,
    prompt,
    mode="default",
    cache_type=None,
    force_llm_cache=False,
):
    """Generic cache handling function"""
    if hashing_kv is None or not (
        force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
    ):
        return None, None, None, None

    if mode != "default":
        # Get embedding cache configuration
        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config",
            {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
        )
        is_embedding_cache_enabled = embedding_cache_config["enabled"]
        use_llm_check = embedding_cache_config.get("use_llm_check", False)

        quantized = min_val = max_val = None
        if is_embedding_cache_enabled:
            # Use embedding cache
            current_embedding = await hashing_kv.embedding_func([prompt])
            llm_model_func = hashing_kv.global_config.get("llm_model_func")
            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
            best_cached_response = await get_best_cached_response(
                hashing_kv,
                current_embedding[0],
                similarity_threshold=embedding_cache_config["similarity_threshold"],
                mode=mode,
                use_llm_check=use_llm_check,
                llm_func=llm_model_func if use_llm_check else None,
                original_prompt=prompt,
                cache_type=cache_type,
            )
            if best_cached_response is not None:
                logger.info(f"Embedding cached hit(mode:{mode} type:{cache_type})")
                return best_cached_response, None, None, None
            else:
                # If keyword embedding caching is enabled, return the quantized embedding so it can be saved later
                logger.info(f"Embedding cached missed(mode:{mode} type:{cache_type})")
                return None, quantized, min_val, max_val

    # For the default mode, or when the embedding cache is disabled, use the regular cache
    # (the default mode serves extract_entities and naive query)
    if exists_func(hashing_kv, "get_by_mode_and_id"):
        mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
    else:
        mode_cache = await hashing_kv.get_by_id(mode) or {}
    if args_hash in mode_cache:
        logger.info(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
        return mode_cache[args_hash]["return"], None, None, None

    logger.info(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
    return None, None, None, None


@dataclass
class CacheData:
    args_hash: str
    content: str
    prompt: str
    quantized: np.ndarray | None = None
    min_val: float | None = None
    max_val: float | None = None
    mode: str = "default"
    cache_type: str = "query"


async def save_to_cache(hashing_kv, cache_data: CacheData):
    if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
        return

    if exists_func(hashing_kv, "get_by_mode_and_id"):
        mode_cache = (
            await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
            or {}
        )
    else:
        mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}

    mode_cache[cache_data.args_hash] = {
        "return": cache_data.content,
        "cache_type": cache_data.cache_type,
        "embedding": cache_data.quantized.tobytes().hex()
        if cache_data.quantized is not None
        else None,
        "embedding_shape": cache_data.quantized.shape
        if cache_data.quantized is not None
        else None,
        "embedding_min": cache_data.min_val,
        "embedding_max": cache_data.max_val,
        "original_prompt": cache_data.prompt,
    }

    await hashing_kv.upsert({cache_data.mode: mode_cache})


def safe_unicode_decode(content):
    # Regular expression to find all Unicode escape sequences of the form \uXXXX
    unicode_escape_pattern = re.compile(r"\\u([0-9a-fA-F]{4})")

    # Function to replace the Unicode escape with the actual character
    def replace_unicode_escape(match):
        # Convert the matched hexadecimal value into the actual Unicode character
        return chr(int(match.group(1), 16))

    # Perform the substitution
    decoded_content = unicode_escape_pattern.sub(
        replace_unicode_escape, content.decode("utf-8")
    )

    return decoded_content


def exists_func(obj, func_name: str) -> bool:
    """Check whether a callable named ``func_name`` exists on ``obj``.

    :param obj: Object to inspect
    :param func_name: Name of the attribute to look up
    :return: True if the attribute exists and is callable, False otherwise
    """
    if callable(getattr(obj, func_name, None)):
        return True
    else:
        return False


def get_conversation_turns(
    conversation_history: list[dict[str, Any]], num_turns: int
) -> str:
    """
    Process conversation history to get the specified number of complete turns.

    Args:
        conversation_history: List of conversation messages in chronological order
        num_turns: Number of complete turns to include

    Returns:
        Formatted string of the conversation history
    """
    # Group messages into turns
    turns: list[list[dict[str, Any]]] = []
    messages: list[dict[str, Any]] = []

    # First, filter out keyword extraction messages
    for msg in conversation_history:
        if msg["role"] == "assistant" and (
            msg["content"].startswith('{ "high_level_keywords"')
            or msg["content"].startswith("{'high_level_keywords'")
        ):
            continue
        messages.append(msg)

    # Then process messages in chronological order
    i = 0
    while i < len(messages) - 1:
        msg1 = messages[i]
        msg2 = messages[i + 1]

        # Check if we have a user-assistant or assistant-user pair
        if (msg1["role"] == "user" and msg2["role"] == "assistant") or (
            msg1["role"] == "assistant" and msg2["role"] == "user"
        ):
            # Always put user message first in the turn
            if msg1["role"] == "assistant":
                turn = [msg2, msg1]  # user, assistant
            else:
                turn = [msg1, msg2]  # user, assistant
            turns.append(turn)
        i += 2

    # Keep only the most recent num_turns
    if len(turns) > num_turns:
        turns = turns[-num_turns:]

    # Format the turns into a string
    formatted_turns: list[str] = []
    for turn in turns:
        formatted_turns.extend(
            [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
        )

    return "\n".join(formatted_turns)


def extract_reasoning(response: str, tag: str) -> ReasoningResponse:
    """Extract the reasoning section and the following section from the LLM response.

    Args:
        response: LLM response
        tag: Tag to extract

    Returns:
        ReasoningResponse: Reasoning section and following section
    """
    soup = bs4.BeautifulSoup(response, "html.parser")

    reasoning_section = soup.find(tag)
    if reasoning_section is None:
        return ReasoningResponse(None, response, tag)
    reasoning_content = reasoning_section.get_text().strip()

    after_reasoning_section = reasoning_section.next_sibling
    if after_reasoning_section is None:
        return ReasoningResponse(reasoning_content, "", tag)
    after_reasoning_content = after_reasoning_section.get_text().strip()

    return ReasoningResponse(reasoning_content, after_reasoning_content, tag)