KAG/kag/solver/main_solver.py
royzhao e1012d39e4
feat(solver): support kag thinker (#640)
* feat(kag): update to v0.7 (#456)

* add think cost

* update csv scanner

* add final rerank

* add reasoner

* add iterative planner

* fix dpr search

* fix dpr search

* add reference data

* move odps import

* update requirement.txt

* update 2wiki

* add missing file

* fix markdown reader

* add iterative planning

* update version

* update runner

* update 2wiki example

* update bridge

* merge solver and solver_new

* add cur day

* writer delete

* update multi process

* add missing files

* fix report

* add chunk retrieved executor

* update try in stream runner result

* add path

* add math executor

* update hotpotqa example

* remove log

* fix python coder solver

* update hotpotqa example

* fix python coder solver

* update config

* fix bad

* add log

* remove unused code

* commit with task thought

* move kag model to common

* add default chat llm

* fix

* use static planner

* support chunk graph node

* add args

* support naive rag

* llm client support tool calls

* add default async

* add openai

* fix result

* fix markdown reader

* fix thinker

* update asyncio interface

* feat(solver): add mcp support (#444)

* upload mcp client related code

* 1. complete an mcp client invocation flow, from pipeline to planner and executor
2. allow multiple mcp_servers in the json config, invoked and selected by the LLM
3. verify the baidu_map_mcp integration end to end

* 1. schema

* bugfix: remove redundant code

---------

Co-authored-by: wanxingyu.wxy <wanxingyu.wxy@antgroup.com>

* fix affairqa after solver refactor

* fix affairqa after solver refactor

* fix readme

* add params

* update version

* update mcp executor

* update mcp executor

* solver add mcp executor

* add missing file

* add mcp executor

* add executor

* x

* update

* fix requirement

* fix main llm config

* fix solver

* bugfix: fix the invoke function call logic

* chg eva

* update example

* add kag layer

* add step task

* support dot refresh

* support dot refresh

* support dot refresh

* support dot refresh

* add retrieved num

* add retrieved num

* add pipelineconf

* update ppr

* update musique prompts

* update

* add to_dict for BuilderComponentData

* async build

* add deduce prompt

* add deduce prompt

* add deduce prompt

* fix reader

* add deduce prompt

* add page thinker report

* modify prompt

* add step status

* add self cognition

* add self cognition

* add memory graph storage

* add now time

* update memory config

* add now time

* chg graph loader

* add the prqa dataset and code

* bugfix: fix the prqa invocation logic

* optimize: refine the code logic and normalize generated answers

* add retry py code

* update memory graph

* update memory graph

* fix

* fix ner

* add with_out_refer generator prompt

* fix

* close ckpt

* fix query

* fix query

* update version

* add llm checker

* add llm checker

* 1. upload evalutor.py and modify the gold_answer.json format
2. refine the code logic
3. update the README.md file

* update exp

* update exp

* rerank support

* add static rewrite query

* recall more chunks

* fix graph load

* add static rewrite query

* fix bugs

* add finish check

* add finish check

* add finish check

* add finish check

* 1. upload the evalutor.py results
2. refine the code logic and the readme

* add lf retry

* add memory graph api

* fix reader api

* add ner

* add metrics

* fix bug

* remove ner

* add reraise for retry

* add edge prop to memory graph

* add memory graph

* 1. fix the evaluation dataset results
2. improve the evaluator.py code
3. remove questions that have a gold_answer entry but no result

* remove the evaluation result files

* fix knext host addr

* async eva

* add lf prompt

* add lf prompt

* add config

* add retry

* add unknown check

* add rc result

* add rc result

* add rc result

* add rc result

* adjust the code to the kag pipeline format and pass tests

* bugfix: remove redundant code

* fix report prompt

* bugfix: trigger the retry mechanism

* bugfix: fix Chinese punctuation errors

* fix rethinker prompt

* update version to 0.6.2b78

* update version

* 1. modify evaluator.py to compute accuracy via the LLM, matching the latest invocation logic
2. modify the prompt so unanswered results are re-tested

* update affairqa for evaluate

* update affairqa for evaluate

* bugfix: fix the dataset

* bugfix: fix the dataset

* bugfix: fix the dataset

* fix name conflict

* bugfix: remove an erroneous question

* bugfix: evaluator failed due to a misnamed file

* update for affairqa eval

* bugfix: keep the evaluate logic consistent

* x

* update for affairqa readme

* remove temp eval scripts

* bugfix for math deduce

* merge 0.6.2_dev

* merge 0.6.2_dev

* fix

* update client addr

* updated version

* update for affairqa eval

* evaUtils supports Chinese

* fix affairqa eval

* remove unused example

* update kag config

* fix default value

* update readme

* fix init

* update comments and add descriptions for some classes

* update example config

* Tc 0.7.0 (#459)

* submit the affairQA code

* fix affairqa eval

---------

Co-authored-by: zhengke.gzk <zhengke.gzk@antgroup.com>

* fix all examples

* reformat

---------

Co-authored-by: peilong <peilong.zpl@antgroup.com>
Co-authored-by: 锦呈 <zhangxinhong.zxh@antgroup.com>
Co-authored-by: wanxingyu.wxy <wanxingyu.wxy@antgroup.com>
Co-authored-by: zhengke.gzk <zhengke.gzk@antgroup.com>

* update chunk metadata

* update chunk metadata

* add debug reporter

* update table text

* add server

* fix math executor

* update api-key for openai vec

* update

* fix naive rag bug

* format code

* fix

---------

Co-authored-by: zhuzhongshu123 <152354526+zhuzhongshu123@users.noreply.github.com>
Co-authored-by: 锦呈 <zhangxinhong.zxh@antgroup.com>
Co-authored-by: wanxingyu.wxy <wanxingyu.wxy@antgroup.com>
Co-authored-by: zhengke.gzk <zhengke.gzk@antgroup.com>
2025-07-08 17:44:32 +08:00


# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import asyncio
import copy
import json
import logging
import os
import re
import yaml
from kag.common.conf import KAGConfigMgr, KAGConfigAccessor, KAGConstants
from kag.indexer import KAGIndexManager
from kag.interface import SolverPipelineABC
from knext.project.client import ProjectClient
from kag.common.conf import KAG_CONFIG, KAG_PROJECT_CONF
from kag.interface.solver.reporter_abc import ReporterABC
logger = logging.getLogger()
def get_all_placeholders(config, placeholders):
if isinstance(config, dict):
for key, value in config.items():
get_all_placeholders(value, placeholders)
elif isinstance(config, list):
return [get_all_placeholders(item, placeholders) for item in config]
elif isinstance(config, str):
if config.startswith("{") and config.endswith("}"):
            placeholder = config[1:-1]  # strip the surrounding braces
placeholders.append(placeholder)
return config
else:
return config
def replace_placeholders(config, replacements):
if isinstance(config, dict):
return {
key: replace_placeholders(value, replacements)
for key, value in config.items()
}
elif isinstance(config, list):
return [replace_placeholders(item, replacements) for item in config]
elif isinstance(config, str):
if config.startswith("{") and config.endswith("}"):
            placeholder = config[1:-1]  # strip the surrounding braces
if placeholder in replacements:
return replacements[placeholder]
else:
raise RuntimeError(f"Placeholder '{placeholder}' not found in config.")
return config
else:
return config
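A standalone sketch of how the two placeholder helpers above behave on a toy config. The helpers are restated (trimmed) so the snippet runs on its own; the config values are hypothetical:

```python
# Re-statement of the two placeholder helpers above, trimmed for a
# self-contained demo. Strings of the form "{name}" are collected by
# get_all_placeholders and substituted by replace_placeholders.
def get_all_placeholders(config, placeholders):
    if isinstance(config, dict):
        for value in config.values():
            get_all_placeholders(value, placeholders)
    elif isinstance(config, list):
        for item in config:
            get_all_placeholders(item, placeholders)
    elif isinstance(config, str) and config.startswith("{") and config.endswith("}"):
        placeholders.append(config[1:-1])  # strip the braces


def replace_placeholders(config, replacements):
    if isinstance(config, dict):
        return {k: replace_placeholders(v, replacements) for k, v in config.items()}
    if isinstance(config, list):
        return [replace_placeholders(item, replacements) for item in config]
    if isinstance(config, str) and config.startswith("{") and config.endswith("}"):
        return replacements[config[1:-1]]  # missing keys are checked upstream
    return config


# Hypothetical pipeline fragment with two placeholders.
conf = {"llm": "{chat_llm}", "retrievers": [{"vectorizer": "{vec_model}"}]}
found = []
get_all_placeholders(conf, found)
resolved = replace_placeholders(
    conf, {"chat_llm": {"type": "openai"}, "vec_model": {"type": "bge"}}
)
```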
def load_yaml_files_from_conf_dir():
current_dir = os.path.dirname(os.path.abspath(__file__))
conf_dir = os.path.join(current_dir, "pipelineconf")
if not os.path.exists(conf_dir) or not os.path.isdir(conf_dir):
        raise FileNotFoundError(f"The 'pipelineconf' directory does not exist at {conf_dir}")
yaml_data = {}
for filename in os.listdir(conf_dir):
if filename.endswith(".yml") or filename.endswith(".yaml"):
file_path = os.path.join(conf_dir, filename)
with open(file_path, "r", encoding="utf-8") as file:
yaml_content = yaml.safe_load(file)
yaml_data[yaml_content["pipeline_name"]] = yaml_content
return yaml_data
def get_pipeline_conf(use_pipeline_name, config):
pipeline_name = "solver_pipeline"
conf_map = load_yaml_files_from_conf_dir()
if use_pipeline_name not in conf_map:
raise RuntimeError(
f"Pipeline configuration not found for pipeline_name: {use_pipeline_name}"
)
placeholders = []
get_all_placeholders(conf_map[use_pipeline_name], placeholders)
placeholders = list(set(placeholders))
placeholders_replacement_map = {}
for placeholder in placeholders:
value = config.get(placeholder)
backup_key = None
if value is None:
if "llm" in placeholder:
backup_key = "llm"
if "vectorizer" in placeholder:
backup_key = "vectorizer"
if backup_key:
value = config.get(backup_key)
        if value is None:
            keys = f"'{placeholder}' or '{backup_key}'" if backup_key else f"'{placeholder}'"
            raise RuntimeError(f"Placeholder {keys} not found in config.")
if "llm" in placeholder or "vectorizer" in placeholder:
value["enable_check"] = False
placeholders_replacement_map[placeholder] = value
default_pipeline_conf = replace_placeholders(
conf_map[use_pipeline_name], placeholders_replacement_map
)
default_solver_pipeline = default_pipeline_conf[pipeline_name]
if use_pipeline_name == "mcp_pipeline":
mcp_servers = config["kb"][0]["mcp_servers"]
logger.info(f"mcp_servers = {mcp_servers}")
logger.info(f"config = {config}")
mcp_executors = []
if mcp_servers is not None:
for mcp_name, mcp_conf in mcp_servers.items():
desc = mcp_conf.get("description", "")
env = mcp_conf.get("env", {})
store_path = mcp_conf.get("store_path", "")
mcp_executors.append(
{
"type": "mcp_executor",
"store_path": store_path,
"name": mcp_name,
"description": desc,
"env": env,
"llm": config.get("llm"),
}
)
else:
raise RuntimeError("mcpServers not found in config.")
default_solver_pipeline["executors"] = mcp_executors
return default_solver_pipeline
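The backup-key fallback inside `get_pipeline_conf` can be isolated as a small helper. This is a sketch with a hypothetical name, not part of the module:

```python
def resolve_placeholder(placeholder, config):
    # Mirrors the fallback in get_pipeline_conf: a placeholder whose name
    # contains "llm" falls back to config["llm"], one containing
    # "vectorizer" falls back to config["vectorizer"]; anything still
    # unresolved is an error.
    value = config.get(placeholder)
    if value is None:
        if "llm" in placeholder:
            value = config.get("llm")
        elif "vectorizer" in placeholder:
            value = config.get("vectorizer")
    if value is None:
        raise RuntimeError(f"Placeholder '{placeholder}' not found in config.")
    return value
```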
def is_chinese(text):
chinese_pattern = re.compile(r"[\u4e00-\u9fff]+")
return bool(chinese_pattern.search(text))
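A quick standalone check of the CJK detection used above to pick the response language (the regex is the same character class, matching the CJK Unified Ideographs block):

```python
import re


def is_chinese(text):
    # True when the text contains at least one CJK Unified Ideograph.
    return bool(re.search(r"[\u4e00-\u9fff]", text))
```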
async def do_qa_pipeline(
use_pipeline, query, qa_config, reporter, task_id, kb_project_ids
):
retriever_configs = []
kb_configs = qa_config.get("kb", [])
for kb_project_id in kb_project_ids:
kb_task_project_id = f"{task_id}_{kb_project_id}"
try:
kag_config = KAGConfigAccessor.get_config(kb_task_project_id)
matched_kb = next(
(kb for kb in kb_configs if kb.get("id") == kb_project_id), None
)
if not matched_kb:
reporter.warning(
f"Knowledge base with id {kb_project_id} not found in qa_config['kb']"
)
continue
index_list = matched_kb.get("index_list", [])
if use_pipeline in ["default_pipeline"]:
# we only use chunk index
index_list = ["chunk_index"]
for index_name in index_list:
index_manager = KAGIndexManager.from_config(
{
"type": index_name,
"llm_config": qa_config.get("llm", {}),
"vectorize_model_config": kag_config.all_config.get(
"vectorize_model", {}
),
}
)
retriever_configs.extend(
index_manager.build_retriever_config(
qa_config.get("llm", {}),
kag_config.all_config.get("vectorize_model", {}),
kag_qa_task_config_key=kb_task_project_id,
)
)
except Exception as e:
logger.error(f"Error processing kb_project_id {kb_project_id}: {str(e)}")
continue
qa_config["retrievers"] = retriever_configs
if use_pipeline in qa_config.keys():
custom_pipeline_conf = copy.deepcopy(qa_config.get(use_pipeline, None))
else:
custom_pipeline_conf = copy.deepcopy(qa_config.get("solver_pipeline", None))
if use_pipeline not in ["index_pipeline"]:
self_cognition_conf = get_pipeline_conf("self_cognition_pipeline", qa_config)
self_cognition_pipeline = SolverPipelineABC.from_config(self_cognition_conf)
self_cognition_res = await self_cognition_pipeline.ainvoke(
query, reporter=reporter
)
else:
self_cognition_res = False
if not self_cognition_res:
if custom_pipeline_conf:
pipeline_config = custom_pipeline_conf
else:
pipeline_config = get_pipeline_conf(use_pipeline, qa_config)
        logger.info(f"pipeline conf: \n{pipeline_config}")
pipeline = SolverPipelineABC.from_config(pipeline_config)
answer = await pipeline.ainvoke(query, reporter=reporter)
else:
answer = self_cognition_res
return answer
async def qa(task_id, query, project_id, host_addr, app_id, params={}):
main_config = params.get("config", KAGConfigAccessor.get_config().all_config)
if isinstance(main_config, str):
main_config = json.loads(main_config)
KAG_PROJECT_CONF.host_addr = host_addr
KAG_PROJECT_CONF.language = "zh" if is_chinese(query) else "en"
use_pipeline = (
main_config["chat"]["ename"]
if isinstance(main_config.get("chat"), dict)
else params.get("usePipeline", "think_pipeline")
)
# process llm
if "extra_body" in main_config["llm"] and main_config["llm"]["type"] in [
"openai",
"ant_openai",
"maas",
"vllm",
]:
extra_body = main_config["llm"]["extra_body"]
if isinstance(extra_body, str):
try:
extra_body_json = json.loads(extra_body)
            except json.JSONDecodeError:
extra_body_json = {}
main_config["llm"]["extra_body"] = extra_body_json
kb_configs = {}
kb_project_ids = []
vectorize_model = {}
global_index_set = main_config.get("chat", {}).get("index_list", [])
if isinstance(main_config.get("kb"), list):
kbs = main_config["kb"]
for kb in kbs:
try:
kb_project_id = kb.get("id") or kb.get("project", {}).get("id")
if not kb_project_id:
continue
kb_project_ids.append(kb_project_id)
kb_task_project_id = f"{task_id}_{kb_project_id}"
kb_conf = KAGConfigMgr()
kb_conf.update_conf(kb)
global_config = kb.get(KAGConstants.PROJECT_CONFIG_KEY, {})
kb_conf.global_config.initialize(**global_config)
project_client = ProjectClient(
host_addr=host_addr, project_id=kb_project_id
)
project = project_client.get_by_id(kb_project_id)
kb_conf.global_config.project_id = kb_project_id
kb_conf.global_config.namespace = project.namespace
kb_conf.global_config.host_addr = host_addr
kb_conf.global_config.language = KAG_PROJECT_CONF.language
if "llm" in main_config:
kb_conf.update_conf({"llm": main_config["llm"]})
if "vectorizer" in kb:
kb_conf.update_conf({"vectorize_model": kb["vectorizer"]})
vectorize_model = kb["vectorizer"]
if "index_list" not in kb and global_index_set:
kb["index_list"] = global_index_set
KAGConfigAccessor.set_task_config(kb_task_project_id, kb_conf)
kb_configs[kb_project_id] = (kb_task_project_id, kb_conf)
except Exception as e:
                logger.error(f"Failed to initialize KB config: {str(e)}", exc_info=True)
if "vectorize_model" not in main_config.keys():
main_config["vectorize_model"] = vectorize_model
if vectorize_model:
KAG_CONFIG.update_conf({"vectorize_model": vectorize_model})
if main_config["llm"]:
KAG_CONFIG.update_conf({"llm": main_config["llm"]})
reporter_map = {"kag_thinker_pipeline": "kag_open_spg_reporter"}
reporter_config = {
"type": reporter_map.get(use_pipeline, "open_spg_reporter"),
"task_id": task_id,
"host_addr": host_addr,
"project_id": project_id,
"thinking_enabled": use_pipeline
in ["think_pipeline", "index_pipeline", "kag_thinker_pipeline"],
"report_all_references": use_pipeline == "index_pipeline",
}
reporter = ReporterABC.from_config(reporter_config)
try:
await reporter.start()
answer = await do_qa_pipeline(
use_pipeline,
query,
main_config,
reporter,
task_id=task_id,
kb_project_ids=kb_project_ids,
)
reporter.add_report_line("answer", "Final Answer", answer, "FINISH")
except Exception as e:
logger.warning(
f"An exception occurred while processing query: {query}. Error: {str(e)}",
exc_info=True,
)
if is_chinese(query):
answer = f"抱歉,处理查询 {query} 时发生异常。错误:{str(e)}, 请重试。"
else:
            answer = f"Sorry, an exception occurred while processing query: {query}. Error: {str(e)}. Please retry."
reporter.add_report_line("answer", "Final Answer", answer, "ERROR")
finally:
await reporter.stop()
return answer
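The tolerant `extra_body` handling inside `qa()` can be sketched as a standalone helper (hypothetical name, not part of the module): accept a dict as-is, parse a JSON string, and fall back to an empty dict on bad input.

```python
import json


def parse_extra_body(extra_body):
    # Accept a dict unchanged; parse a JSON string; fall back to {} when
    # the string is not valid JSON, mirroring the try/except in qa().
    if isinstance(extra_body, str):
        try:
            return json.loads(extra_body)
        except json.JSONDecodeError:
            return {}
    return extra_body or {}
```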
class SolverMain:
def invoke(
self,
project_id: int,
task_id,
query: str,
session_id: str = "0",
is_report=True,
host_addr="http://127.0.0.1:8887",
params=None,
app_id="",
):
answer = None
if params is None:
params = {}
try:
answer = asyncio.run(
qa(
task_id=task_id,
project_id=project_id,
host_addr=host_addr,
query=query,
params=params,
app_id=app_id,
)
)
logger.info(f"{query} answer={answer}")
except Exception as e:
import traceback
traceback.print_exc()
logger.warning(
f"An exception occurred while processing query: {query}. Error: {str(e)}",
exc_info=True,
)
return answer
async def ainvoke(
self,
project_id: int,
task_id: int,
query: str,
session_id: str = "0",
is_report=True,
host_addr="http://127.0.0.1:8887",
params=None,
app_id="",
):
answer = None
if params is None:
params = {}
try:
answer = await qa(
task_id=task_id,
project_id=project_id,
host_addr=host_addr,
query=query,
params=params,
app_id=app_id,
)
logger.info(f"{query} answer={answer}")
except Exception as e:
import traceback
traceback.print_exc()
logger.warning(
f"An exception occurred while processing query: {query}. Error: {str(e)}",
exc_info=True,
)
return answer
if __name__ == "__main__":
# init_kag_config(
# "4200052", "https://spg-pre.alipay.com"
# )
config = {}
params = {"config": config}
res = SolverMain().invoke(
2100007,
11200009,
# "阿里巴巴2024年截止到9月30日的总收入是多少元 如果把这笔钱于当年10月3日存入银行并于12月29日取出银行日利息是万分之0.9,本息共可取出多少元?",
"营业执照不通过",
"9500005",
True,
host_addr="http://spg-pre.alipay.com",
# host_addr="http://antspg-gz00b-006001164035.sa128-sqa.alipay.net:8887",
params=params,
)
print("*" * 80)
print("The Answer is: ", res)