mirror of
https://github.com/OpenSPG/KAG.git
synced 2025-11-11 16:08:33 +00:00
518 lines
14 KiB
Python
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
# flake8: noqa
import datetime
import random
import re
import string
import sys
import json
import hashlib
import os
import tempfile
import time
import uuid
import subprocess
import shlex

import requests
import importlib
import numpy as np
from typing import Tuple, TypeVar, Optional
from pathlib import Path

from shutil import copystat, copy2
from typing import Any, Union
from jinja2 import Environment, FileSystemLoader, Template
from stat import S_IWUSR as OWNER_WRITE_PERMISSION
from tenacity import retry, stop_after_attempt
from aiolimiter import AsyncLimiter

# ANSI escape codes for colored terminal output
reset = "\033[0m"
bold = "\033[1m"
underline = "\033[4m"
red = "\033[31m"
green = "\033[32m"
yellow = "\033[33m"
blue = "\033[34m"
magenta = "\033[35m"
cyan = "\033[36m"
white = "\033[37m"
def run_cmd(cmd, catch_stdout=True, catch_stderr=True, shell=False):
    """Run a command and optionally capture its stdout/stderr."""
    args = shlex.split(cmd)
    if catch_stdout:
        stdout = subprocess.PIPE
    else:
        stdout = None
    if catch_stderr:
        stderr = subprocess.PIPE
    else:
        stderr = None
    # subprocess expects the original command string when shell=True;
    # the pre-split argument list is only valid with shell=False.
    result = subprocess.run(
        cmd if shell else args, stdout=stdout, stderr=stderr, shell=shell
    )
    return result
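# Illustrative usage sketch (added for clarity, not part of the original module):
# shows how run_cmd captures output; the echoed string is an arbitrary example
# and assumes an `echo` binary is available on the host.
def _example_run_cmd():
    result = run_cmd("echo hello")
    # stdout is captured as bytes because catch_stdout defaults to True
    return result.stdout.decode("utf-8").strip()  # -> "hello"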
def append_python_path(path: str) -> bool:
    """
    Append the given path to `sys.path`.
    """
    path = Path(path).resolve()
    path = str(path)
    if path not in sys.path:
        sys.path.append(path)
        return True
    return False


def render_template(
    root_dir: Union[str, os.PathLike], file: Union[str, os.PathLike], **kwargs: Any
) -> None:
    env = Environment(loader=FileSystemLoader(root_dir))
    template = env.get_template(str(file))
    content = template.render(kwargs)

    path_obj = Path(root_dir) / file
    render_path = path_obj.with_suffix("") if path_obj.suffix == ".tmpl" else path_obj

    if path_obj.suffix == ".tmpl":
        path_obj.rename(render_path)

    render_path.write_text(content, "utf8")


def copytree(src: Path, dst: Path, **kwargs):
    names = [x.name for x in src.iterdir()]

    if not dst.exists():
        dst.mkdir(parents=True)

    for name in names:
        _name = Template(name).render(**kwargs)
        src_name = src / name
        dst_name = dst / _name
        if src_name.is_dir():
            copytree(src_name, dst_name, **kwargs)
        else:
            copyfile(src_name, dst_name, **kwargs)

    copystat(src, dst)
    _make_writable(dst)


def copyfile(src: Path, dst: Path, **kwargs):
    if dst.exists():
        return
    dst = Path(Template(str(dst)).render(**kwargs))
    copy2(src, dst)
    _make_writable(dst)
    if dst.suffix != ".tmpl":
        return
    render_template("/", dst, **kwargs)


def remove_files_except(path, file, new_file):
    for filename in os.listdir(path):
        file_path = os.path.join(path, filename)
        if os.path.isfile(file_path) and filename != file:
            os.remove(file_path)
    os.rename(path / file, path / new_file)


def _make_writable(path):
    current_permissions = os.stat(path).st_mode
    os.chmod(path, current_permissions | OWNER_WRITE_PERMISSION)


def escape_single_quotes(s: str):
    return s.replace("'", "\\'")
def load_json(content):
    try:
        return json.loads(content)
    except json.JSONDecodeError as e:
        # Best-effort fallback: retry on the prefix that ends just before the
        # reported error column (helps with single-line JSON followed by junk).
        substr = content[: e.colno - 1]
        return json.loads(substr)
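# Illustrative sketch (not part of the original module): extra text after a
# single-line JSON object triggers load_json's fallback path. The sample
# payload is hypothetical.
def _example_load_json():
    return load_json('{"a": 1} trailing text')  # -> {"a": 1}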
def flatten_2d_list(nested_list):
    return [item for sublist in nested_list for item in sublist]


def split_module_class_name(name: str, text: str) -> Tuple[str, str]:
    """
    Split `name` as module name and class name pair.

    :param name: fully qualified class name, e.g. ``foo.bar.MyClass``
    :type name: str
    :param text: describe the kind of the class, used in the exception message
    :type text: str
    :rtype: Tuple[str, str]
    :raises RuntimeError: if `name` is not a fully qualified class name
    """
    i = name.rfind(".")
    if i == -1:
        message = "invalid %s class name: %s" % (text, name)
        raise RuntimeError(message)
    module_name = name[:i]
    class_name = name[i + 1 :]
    return module_name, class_name
def dynamic_import_class(name: str, text: str):
    """
    Import the class specified by `name` dynamically.

    :param name: fully qualified class name, e.g. ``foo.bar.MyClass``
    :type name: str
    :param text: describe the kind of the class, used in the exception message
    :type text: str
    :raises RuntimeError: if `name` is not a fully qualified class name, or
        the class is not in the module specified by `name`
    :raises ModuleNotFoundError: the module specified by `name` is not found
    """
    module_name, class_name = split_module_class_name(name, text)
    module = importlib.import_module(module_name)
    class_ = getattr(module, class_name, None)
    if class_ is None:
        message = "class %r not found in module %r" % (class_name, module_name)
        raise RuntimeError(message)
    if not isinstance(class_, type):
        message = "%r is not a class" % (name,)
        raise RuntimeError(message)
    return class_
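# Illustrative sketch (not part of the original module): dynamic_import_class
# resolves a dotted path to the class object itself; pathlib.Path is used here
# purely as a stdlib example target, and the "scanner" label only shows up in
# error messages.
def _example_dynamic_import():
    cls = dynamic_import_class("pathlib.Path", "scanner")
    return cls("/tmp")  # equivalent to pathlib.Path("/tmp")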
def processing_phrases(phrase):
    phrase = str(phrase)
    return re.sub("[^A-Za-z0-9\u4e00-\u9fa5 ]", " ", phrase.lower()).strip()


def to_camel_case(phrase):
    s = processing_phrases(phrase).replace(" ", "_")
    return "".join(
        word.capitalize() if i != 0 else word for i, word in enumerate(s.split("_"))
    )


def to_snake_case(name):
    words = re.findall("[A-Za-z][a-z0-9]*", name)
    result = "_".join(words).lower()
    return result
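# Illustrative sketch (not part of the original module): round-tripping a name
# through the two casing helpers. Note that to_camel_case keeps the first word
# lower-cased, so the result is lowerCamelCase.
def _example_casing():
    snake = to_snake_case("ChunkRetrievedExecutor")  # -> "chunk_retrieved_executor"
    camel = to_camel_case("chunk retrieved executor")  # -> "chunkRetrievedExecutor"
    return snake, camel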
def get_vector_field_name(property_key: str):
    name = f"{property_key}_vector"
    name = to_snake_case(name)
    return "_" + name


def get_sparse_vector_field_name(property_key: str):
    name = f"{property_key}_sparse"
    name = to_snake_case(name)
    return "_" + name
def split_list_into_n_parts(lst, n):
    length = len(lst)
    part_size = length // n
    seg = [x * part_size for x in range(n)]
    seg.append(min(length, part_size * n))

    remainder = length % n

    result = []

    # Split the list
    start = 0
    for i in range(n):
        # Number of elements in the current part
        if i < remainder:
            end = start + part_size + 1
        else:
            end = start + part_size

        # Append the current part to the result
        result.append(lst[start:end])

        # Move the start position forward
        start = end

    return result
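# Illustrative sketch (not part of the original module): the first
# `length % n` parts receive one extra element, so part sizes differ by at
# most one.
def _example_split_list():
    return split_list_into_n_parts([1, 2, 3, 4, 5, 6, 7], 3)
    # -> [[1, 2, 3], [4, 5], [6, 7]]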
def generate_hash_id(value):
    """
    Generates a deterministic hash ID for the input value.

    If the input value is a dictionary, its items are sorted first so that
    dictionaries with the same content always produce the same hash.
    Otherwise the value is converted to a string directly.

    Args:
        value: The input value to be hashed.

    Returns:
        str: A SHA-256 hex digest generated from the input value.
    """
    if isinstance(value, dict):
        sorted_items = sorted(value.items())
        key = str(sorted_items)
    else:
        key = str(value)  # Ensure key is a string regardless of input type

    # Encode to bytes for hashing
    key = key.encode("utf-8")

    hasher = hashlib.sha256()
    hasher.update(key)

    return hasher.hexdigest()
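# Illustrative sketch (not part of the original module): equal dicts hash to
# the same ID regardless of key order, which is what makes the ID usable as a
# stable checkpoint/deduplication key.
def _example_hash_id():
    a = generate_hash_id({"name": "kag", "version": 7})
    b = generate_hash_id({"version": 7, "name": "kag"})
    return a == b  # -> True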
@retry(stop=stop_after_attempt(3), reraise=True)
def download_from_http(url: str, dest: str = None) -> str:
    """Downloads a file from an HTTP URL and saves it to a temporary directory.

    This function uses the requests library to download a file from the specified
    HTTP URL and saves it to the system's temporary directory. After the download
    is complete, it returns the local path of the downloaded file.

    Args:
        url (str): The HTTP URL of the file to be downloaded.
        dest (str, optional): Destination path; if omitted, the file is saved
            to the system temporary directory under its original name.

    Returns:
        str: The local path of the downloaded file.
    """
    # Send an HTTP GET request to download the file
    response = requests.get(url, stream=True)
    response.raise_for_status()  # Check if the request was successful

    if dest is None:
        # Create a temporary file
        temp_dir = tempfile.gettempdir()
        temp_file_path = os.path.join(temp_dir, os.path.basename(url))
        dest = temp_file_path

    with open(dest, "wb") as temp_file:
        # Write the downloaded content to the temporary file
        for chunk in response.iter_content(chunk_size=1024**2):
            temp_file.write(chunk)

    # Return the path of the downloaded file
    return temp_file.name


class RateLimiterManger:
    def __init__(self):
        self.limiter_map = {}

    def get_rate_limiter(
        self, name: str, max_rate: float = 1000, time_period: float = 1
    ):
        if name not in self.limiter_map:
            limiter = AsyncLimiter(max_rate, time_period)
            self.limiter_map[name] = limiter
        return self.limiter_map[name]
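# Illustrative sketch (not part of the original module): AsyncLimiter is used
# as an async context manager, so callers that request the same name share one
# limiter. The "chat_llm" name and this coroutine are hypothetical.
async def _example_rate_limited_call(manager: RateLimiterManger):
    limiter = manager.get_rate_limiter("chat_llm", max_rate=10, time_period=1)
    async with limiter:
        # At most 10 entries per second across all callers using this name.
        return "called"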
def get_now(language="zh"):
    if language == "zh":
        days_of_week = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
        date_format = "%Y年%m月%d日"
    elif language == "en":
        days_of_week = [
            "Monday",
            "Tuesday",
            "Wednesday",
            "Thursday",
            "Friday",
            "Saturday",
            "Sunday",
        ]
        date_format = "%Y-%m-%d"
    else:
        raise ValueError(
            "Unsupported language. Please use 'zh' for Chinese or 'en' for English."
        )

    today = datetime.datetime.now()
    return today.strftime(date_format) + " (" + days_of_week[today.weekday()] + ")"


def generate_random_string(bit=8):
    possible_characters = string.ascii_letters + string.digits
    random_str = "".join(random.choice(possible_characters) for _ in range(bit))
    return "gen" + random_str


def generate_biz_id_with_type(biz_id, type_name):
    return f"{biz_id}_{type_name}"


def get_p_clean(p):
    if re.search(".*[\\u4e00-\\u9fa5]+.*", p):
        p = re.sub("[ \t::()“”‘’'\"\[\]\(\)]+?", "", p)
    else:
        p = None
    return p


def get_recall_node_label(label_set):
    for l in label_set:
        if l != "Entity":
            return l
    return "Entity"


def node_2_doc(node: dict):
    prop_set = []
    for key in node.keys():
        if key in ["id"]:
            continue
        value = node[key]
        if isinstance(value, list):
            value = "\n".join(value)
        else:
            value = str(value)
        if key == "name":
            prop = f"节点名称:{value}"
        elif key == "description":
            prop = f"描述:{value}"
        else:
            prop = f"{key}:{value}"
        prop_set.append(prop)
    return "\n".join(prop_set)
def extract_content_target(input_string):
    """
    Extract the content and target parts from the input string.

    Args:
        input_string (str): A string containing content and target.

    Returns:
        tuple: A (content, target) pair. If a part is not found, the
        corresponding value is None.
    """
    # Define regex patterns
    # Content may contain newlines and special characters, so use non-greedy mode
    content_pattern = r"content=\[(.*?)\]"
    target_pattern = (
        r"target=([^,\]]+)"  # Assume target does not contain commas or closing brackets
    )

    # Search for content
    content_match = re.search(content_pattern, input_string, re.DOTALL)
    if content_match:
        content = content_match.group(1).strip()
    else:
        content = None

    # Search for target
    target_match = re.search(target_pattern, input_string)
    if target_match:
        target = (
            target_match.group(1).strip().rstrip("'")
        )  # Remove trailing single quote if present
    else:
        target = None
    return content, target


def generate_unique_message_key(message):
    unique_id = uuid.uuid5(uuid.NAMESPACE_URL, str(message))
    timestamp = int(time.time() * 1000)  # current timestamp in milliseconds
    # unique_id = uuid.uuid4().hex  # generate a UUID and convert it to a hex string
    async_message_key = f"KAG_{timestamp}_{unique_id}"
    return async_message_key


def rrf_score(length, r: int = 1):
    return np.array([1 / (r + i) for i in range(length)])
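# Illustrative sketch (not part of the original module): rrf_score produces
# reciprocal-rank-fusion style weights 1/(r+0), 1/(r+1), ... that can be used
# to merge several ranked retrieval lists.
def _example_rrf():
    weights = rrf_score(3)  # -> array([1.0, 0.5, 0.333...])
    return weights.sum()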
T = TypeVar("T")


def resolve_instance(
    instance: Optional[Union[T, dict]],
    default_config: dict,
    from_config_func,
    expected_type=None,
) -> T:
    if isinstance(instance, dict):
        return from_config_func(instance)
    elif instance is None:
        return from_config_func(default_config)
    elif expected_type and not isinstance(instance, expected_type):
        raise TypeError(f"Expected {expected_type}, got {type(instance)}")
    else:
        return instance
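# Illustrative sketch (not part of the original module): resolve_instance
# accepts an already-built object, a config dict, or None (which falls back to
# the default config). The dict factory below stands in for a real
# from_config function.
def _example_resolve_instance():
    default_conf = {"max_rate": 10}
    # None falls back to the default config via the factory
    built = resolve_instance(None, default_conf, from_config_func=dict)
    # An already-built (non-dict) object is returned unchanged
    ready = resolve_instance([1, 2], default_conf, dict, expected_type=list)
    return built, ready  # -> ({'max_rate': 10}, [1, 2])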
def extract_tag_content(text):
    pattern = r"<(\w+)\b[^>]*>(.*?)</\1>|<(\w+)\b[^>]*>([^<]*)|([^<]+)"
    results = []
    for match in re.finditer(pattern, text, re.DOTALL):
        tag1, content1, tag2, content2, raw_text = match.groups()
        if tag1:
            results.append((tag1, content1))  # keep original content (including whitespace)
        elif tag2:
            results.append((tag2, content2))  # keep original content (including whitespace)
        elif raw_text:
            results.append(("", raw_text))  # keep original whitespace
    return results


def extract_specific_tag_content(text, tag):
    # Build a regex that matches the content inside the specified tag
    pattern = rf"<{tag}\b[^>]*>(.*?)</{tag}>"
    matches = re.findall(pattern, text, flags=re.DOTALL)
    return [content.strip() for content in matches]
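# Illustrative sketch (not part of the original module): pulling the body of a
# specific tag out of an LLM response; the <answer> tag name is hypothetical.
def _example_extract_tag():
    text = "prefix <answer>42</answer> suffix"
    return extract_specific_tag_content(text, "answer")  # -> ["42"]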
def extract_box_answer(text):
    pattern = r"\\boxed\{([^}]*)\}"
    extracted_answers = re.findall(pattern, text)
    if len(extracted_answers) == 0:
        return ""
    else:
        return extracted_answers[0]


def remove_boxed(text):
    # Match \boxed{...} and capture the content inside
    pattern = r"\\boxed\{([^}]*)\}"
    # Replace each occurrence with only the content inside the braces
    result = re.sub(pattern, r"\1", text)
    return result


def search_plan_extraction(text):
    text = text.replace("\n", "")
    pattern = r"(?i)<search.*?>.*?</search>"
    matches = re.findall(pattern, text)

    # Extract the content part
    extracted_plans = []
    for match in matches:
        # Non-greedy match to extract the content
        plan = re.search(r"<search.*?>(.*?)</search>", match, re.IGNORECASE).group(1)
        extracted_plans.append(plan)
    if len(extracted_plans) == 0:
        return ""
    else:
        return extracted_plans[0].strip()
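# Illustrative sketch (not part of the original module): the boxed helpers
# mirror the LaTeX \boxed{...} convention often used in model answers, and
# search_plan_extraction pulls the first <search> directive out of a plan.
# The sample strings are hypothetical.
def _example_answer_parsing():
    boxed = extract_box_answer(r"The answer is \boxed{42}.")  # -> "42"
    plain = remove_boxed(r"The answer is \boxed{42}.")  # -> "The answer is 42."
    plan = search_plan_extraction("<search>who founded OpenSPG</search>")
    return boxed, plain, plan  # plan -> "who founded OpenSPG"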