LightRAG/lightrag/utils.py

310 lines
9.4 KiB
Python
Raw Normal View History

2024-10-10 15:02:30 +08:00
import asyncio
import html
2024-10-31 14:31:26 +08:00
import io
import csv
2024-10-10 15:02:30 +08:00
import json
import logging
import os
import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
2024-11-06 11:18:14 -05:00
from typing import Any, Callable, List, Union
2024-10-20 23:08:26 +08:00
import xml.etree.ElementTree as ET
2024-10-10 15:02:30 +08:00
import numpy as np
import tiktoken
# Module-level tiktoken encoder, populated lazily by the tiktoken helper functions.
ENCODER = None
# Package-wide logger; set_logger() attaches a file handler to it.
logger = logging.getLogger(name="lightrag")
2024-10-10 15:02:30 +08:00
def set_logger(log_file: str):
    """Send the module logger's DEBUG-and-above records to ``log_file``.

    The logger level is always set to DEBUG. A file handler is attached only
    when the logger has no handlers yet; otherwise ``log_file`` is ignored and
    the existing handlers are kept.
    """
    logger.setLevel(logging.DEBUG)
    # Check before constructing the handler: logging.FileHandler opens (and
    # creates) the file immediately, so building one that is then discarded
    # would leak an open file descriptor and leave an empty file behind.
    if logger.handlers:
        return
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    logger.addHandler(file_handler)
2024-10-10 15:02:30 +08:00
@dataclass
class EmbeddingFunc:
    """An async embedding callable bundled with metadata about its model.

    Calling an instance awaits ``func`` with the same positional and keyword
    arguments and returns its result unchanged.
    """

    # Presumably the dimensionality of the vectors ``func`` returns — not checked here.
    embedding_dim: int
    # Presumably the model's maximum input length in tokens — not enforced here.
    max_token_size: int
    # Async callable that performs the embedding; it is awaited by __call__.
    func: Callable[..., Any]

    async def __call__(self, *args, **kwargs) -> np.ndarray:
        return await self.func(*args, **kwargs)
2024-10-10 15:02:30 +08:00
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
    """Locate the JSON object body embedded in a string.

    Takes the span from the first ``{`` to the last ``}`` (DOTALL, so it may
    cross lines) and removes literal backslash-n sequences and real newlines.
    If that span already parses as JSON it is returned as-is. Otherwise it is
    treated as single-quoted pseudo-JSON and every ``'`` is converted to
    ``"``. Only converting invalid spans keeps apostrophes inside valid
    values (e.g. "don't") intact.

    Returns None when ``content`` is not a string or contains no ``{...}`` span.
    Schema is not validated; callers must still check the parsed structure.
    """
    if not isinstance(content, str):
        return None
    match = re.search(r"\{.*\}", content, re.DOTALL)
    if match is None:
        return None
    body = match.group(0).replace("\\n", "").replace("\n", "")
    try:
        json.loads(body)
    except json.JSONDecodeError:
        # Not valid JSON as-is: assume single-quoted pseudo-JSON and normalise quotes.
        return body.replace("'", '"')
    return body
2024-10-10 15:02:30 +08:00
def convert_response_to_json(response: str) -> dict:
    """Extract and parse the JSON object embedded in an LLM ``response``.

    Raises:
        AssertionError: if locate_json_string_body_from_string finds no JSON body.
        json.JSONDecodeError: if the located body is not valid JSON (it is logged first).
    """
    json_str = locate_json_string_body_from_string(response)
    # Explicit raise instead of ``assert`` so the check still runs under
    # ``python -O``; AssertionError is kept so existing callers that catch it work.
    if json_str is None:
        raise AssertionError(f"Unable to parse JSON from response: {response}")
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        logger.error(f"Failed to parse JSON: {json_str}")
        raise
2024-10-10 15:02:30 +08:00
def compute_args_hash(*args):
    """Return the hex MD5 digest of ``str()`` of the positional-argument tuple."""
    hasher = md5()
    hasher.update(str(args).encode())
    return hasher.hexdigest()
2024-10-10 15:02:30 +08:00
def compute_mdhash_id(content, prefix: str = ""):
    """Build an id: ``prefix`` followed by the hex MD5 digest of ``content``."""
    hasher = md5()
    hasher.update(content.encode())
    return prefix + hasher.hexdigest()
2024-10-10 15:02:30 +08:00
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
    """Decorate an async function so at most ``max_size`` calls run concurrently.

    Callers beyond the limit poll every ``waitting_time`` seconds until a slot
    frees up. (The misspelled parameter name is kept for caller compatibility.)
    """

    def final_decro(func):
        """Wrap ``func`` with the concurrency limit.

        A plain counter is used instead of asyncio.Semaphore; the original
        note gives the reason as avoiding the need for nest-asyncio.
        """
        active_calls = 0

        @wraps(func)
        async def wait_func(*args, **kwargs):
            nonlocal active_calls
            # No await between the check and the increment, so this is atomic
            # with respect to other tasks on the same event loop.
            while active_calls >= max_size:
                await asyncio.sleep(waitting_time)
            active_calls += 1
            try:
                return await func(*args, **kwargs)
            finally:
                # Release the slot even when func raises or the task is
                # cancelled; otherwise the slot leaks and callers spin forever.
                active_calls -= 1

        return wait_func

    return final_decro
2024-10-10 15:02:30 +08:00
def wrap_embedding_func_with_attrs(**kwargs):
    """Return a decorator that wraps a function in an EmbeddingFunc.

    ``kwargs`` are forwarded to the EmbeddingFunc constructor together with
    the decorated function as ``func``.
    """

    def _decorate(func) -> EmbeddingFunc:
        return EmbeddingFunc(**kwargs, func=func)

    return _decorate
2024-10-10 15:02:30 +08:00
def load_json(file_name):
    """Decode the UTF-8 JSON file ``file_name``; return None if the path does not exist."""
    if os.path.exists(file_name):
        with open(file_name, encoding="utf-8") as handle:
            return json.load(handle)
    return None
2024-10-10 15:02:30 +08:00
def write_json(json_obj, file_name):
    """Write ``json_obj`` to ``file_name`` as UTF-8 JSON (2-space indent, non-ASCII kept)."""
    with open(file_name, mode="w", encoding="utf-8") as out:
        json.dump(json_obj, out, ensure_ascii=False, indent=2)
2024-10-10 15:02:30 +08:00
# Per-model cache of tiktoken encodings, so each model_name uses its own
# tokenizer instead of whichever model happened to be requested first.
_TIKTOKEN_ENCODERS: dict = {}


def _get_tiktoken_encoder(model_name: str):
    """Return the tiktoken encoding for ``model_name``, creating and caching it on first use."""
    global ENCODER
    encoder = _TIKTOKEN_ENCODERS.get(model_name)
    if encoder is None:
        encoder = tiktoken.encoding_for_model(model_name)
        _TIKTOKEN_ENCODERS[model_name] = encoder
    # Keep the legacy module-level ENCODER pointing at the most recently used
    # encoding for any code that still reads it directly.
    ENCODER = encoder
    return encoder


def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
    """Encode ``content`` into token ids with the tokenizer for ``model_name``."""
    return _get_tiktoken_encoder(model_name).encode(content)


def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
    """Decode token ids back into text with the tokenizer for ``model_name``."""
    return _get_tiktoken_encoder(model_name).decode(tokens)
2024-10-10 15:02:30 +08:00
def pack_user_ass_to_openai_messages(*args: str):
    """Turn alternating texts into OpenAI chat messages, starting with the user role."""
    messages = []
    for index, text in enumerate(args):
        role = "assistant" if index % 2 else "user"
        messages.append({"role": role, "content": text})
    return messages
2024-10-10 15:02:30 +08:00
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split ``content`` on any of ``markers``; return the non-empty, stripped pieces.

    Markers are matched literally:
    - Longer markers are tried first, so a marker that is a prefix of another
      (e.g. "##" vs "###") cannot break the longer one apart.
    - Empty-string markers are ignored, because they would match between every
      character.
    - If there are no usable markers, ``[content]`` is returned unchanged (not
      stripped).
    """
    if not markers:
        return [content]
    # De-duplicate (order-preserving), drop empty markers, longest first.
    usable = sorted(dict.fromkeys(m for m in markers if m), key=len, reverse=True)
    if not usable:
        return [content]
    results = re.split("|".join(re.escape(marker) for marker in usable), content)
    return [r.strip() for r in results if r.strip()]
2024-10-10 15:02:30 +08:00
# See the utility functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
    """Trim, HTML-unescape and strip control characters from a string.

    Non-string inputs are returned unchanged.
    """
    if isinstance(input, str):
        unescaped = html.unescape(input.strip())
        # Remove C0 controls, DEL and C1 controls. See
        # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
        return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", unescaped)
    return input
2024-10-10 15:02:30 +08:00
def is_float_regex(value):
    """Return True if ``value`` looks like a plain decimal number (optional sign, no exponent)."""
    match = re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value)
    return match is not None
2024-10-10 15:02:30 +08:00
def truncate_list_by_token_size(
    list_data: list, key: Callable[[Any], str], max_token_size: int
):
    """Return the longest prefix of ``list_data`` that fits in ``max_token_size`` tokens.

    ``key`` maps each item to the text whose tiktoken token count is charged
    against the budget. Returns [] when the budget is non-positive. Returns the
    original list object itself when every item fits.
    """
    if max_token_size <= 0:
        return []
    tokens = 0
    for i, data in enumerate(list_data):
        tokens += len(encode_string_by_tiktoken(key(data)))
        if tokens > max_token_size:
            # Keep only the items before the one that overflowed the budget.
            return list_data[:i]
    return list_data
2024-11-06 11:18:14 -05:00
2024-10-31 14:31:26 +08:00
def list_of_list_to_csv(data: List[List[str]]) -> str:
    """Serialise rows into CSV text using the csv module's default dialect."""
    buffer = io.StringIO()
    csv.writer(buffer).writerows(data)
    return buffer.getvalue()
2024-11-06 11:18:14 -05:00
2024-10-31 14:31:26 +08:00
def csv_string_to_list(csv_string: str) -> List[List[str]]:
    """Parse CSV text into a list of rows, each a list of field strings."""
    return list(csv.reader(io.StringIO(csv_string)))
2024-10-10 15:02:30 +08:00
def save_data_to_file(data, file_name):
    """Dump ``data`` to ``file_name`` as UTF-8 JSON (4-space indent, non-ASCII kept)."""
    with open(file_name, mode="w", encoding="utf-8") as sink:
        json.dump(data, sink, ensure_ascii=False, indent=4)
2024-10-20 23:08:26 +08:00
2024-10-25 13:32:25 +05:30
2024-10-20 23:08:26 +08:00
def _graphml_data_text(element, key, namespace):
    """Return the text of ``element``'s ``<data key=...>`` child, or "" if missing or empty.

    An empty element such as ``<data key="d2" />`` parses with ``.text`` set
    to None, so it is treated like a missing one.
    """
    child = element.find(f"./data[@key='{key}']", namespace)
    if child is None or child.text is None:
        return ""
    return child.text


def xml_to_json(xml_file):
    """Convert a GraphML file into ``{"nodes": [...], "edges": [...]}`` dicts.

    ``xml_file`` is whatever ``ET.parse`` accepts (a filename or file object).

    The ``<data>`` key ids are hard-coded:
    - Nodes read d0 (entity_type, surrounding double quotes stripped),
      d1 (description) and d2 (source_id).
    - Edges read d3 (weight, as float), d4 (description), d5 (keywords) and
      d6 (source_id).
    NOTE(review): these ids depend on the attribute order used by the GraphML
    writer — confirm against the file producer if that changes.

    Missing or empty values become "" (0.0 for weight). Node ids and edge
    endpoints have surrounding double quotes stripped.

    Best-effort: progress is printed to stdout; on any error the error is
    printed and None is returned.
    """
    try:
        tree = ET.parse(xml_file)
        root = tree.getroot()
        # Print the root element's tag and attributes to confirm the file has been correctly loaded
        print(f"Root element: {root.tag}")
        print(f"Root attributes: {root.attrib}")

        data = {"nodes": [], "edges": []}

        # GraphML elements live in this default namespace
        namespace = {"": "http://graphml.graphdrawing.org/xmlns"}

        for node in root.findall(".//node", namespace):
            data["nodes"].append(
                {
                    "id": node.get("id").strip('"'),
                    "entity_type": _graphml_data_text(node, "d0", namespace).strip('"'),
                    "description": _graphml_data_text(node, "d1", namespace),
                    "source_id": _graphml_data_text(node, "d2", namespace),
                }
            )

        for edge in root.findall(".//edge", namespace):
            weight_text = _graphml_data_text(edge, "d3", namespace)
            data["edges"].append(
                {
                    "source": edge.get("source").strip('"'),
                    "target": edge.get("target").strip('"'),
                    # Absent or empty weight defaults to 0.0 instead of float(None) failing.
                    "weight": float(weight_text) if weight_text.strip() else 0.0,
                    "description": _graphml_data_text(edge, "d4", namespace),
                    "keywords": _graphml_data_text(edge, "d5", namespace),
                    "source_id": _graphml_data_text(edge, "d6", namespace),
                }
            )

        # Print the number of nodes and edges found
        print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
        return data
    except ET.ParseError as e:
        print(f"Error parsing XML file: {e}")
        return None
    except Exception as e:
        print(f"An error occurred: {e}")
        return None
2024-11-06 11:18:14 -05:00
def process_combine_contexts(hl, ll):
    """Merge two CSV context tables (``hl`` and ``ll``) into one numbered table.

    Processing steps:
    - Each input is parsed as CSV whose first row is a header; ``ll``'s header
      wins when both inputs have rows.
    - In every data row the first column is dropped (presumably a row id) and
      the remaining fields are re-joined with "," (no CSV re-quoting).
    - Duplicates are removed, keeping first appearance with ``hl`` rows before
      ``ll`` rows.
    - Rows are renumbered from 1, and the header is joined with ",\\t".

    Returns "" when neither input yields any rows.
    """
    hl_rows = csv_string_to_list(hl.strip())
    ll_rows = csv_string_to_list(ll.strip())

    header = None
    if hl_rows:
        header, hl_rows = hl_rows[0], hl_rows[1:]
    if ll_rows:
        header, ll_rows = ll_rows[0], ll_rows[1:]
    if header is None:
        return ""

    # Drop the leading column of each non-empty row and flatten it to one string.
    bodies = [",".join(row[1:]) for row in hl_rows + ll_rows if row]
    # dict.fromkeys keeps the first occurrence of each non-empty body, in order.
    unique_bodies = list(dict.fromkeys(body for body in bodies if body))

    lines = [",\t".join(header)]
    lines.extend(
        f"{number},\t{body}" for number, body in enumerate(unique_bodies, start=1)
    )
    return "\n".join(lines)