mirror of https://github.com/OpenSPG/KAG.git, synced 2025-11-03 19:45:17 +00:00
	feat(solver): support kag thinker (#640)
* feat(kag): update to v0.7 (#456) * add think cost * update csv scanner * add final rerank * add reasoner * add iterative planner * fix dpr search * fix dpr search * add reference data * move odps import * update requirement.txt * update 2wiki * add missing file * fix markdown reader * add iterative planning * update version * update runner * update 2wiki example * update bridge * merge solver and solver_new * add cur day * writer delete * update multi process * add missing files * fix report * add chunk retrieved executor * update try in stream runner result * add path * add math executor * update hotpotqa example * remove log * fix python coder solver * update hotpotqa example * fix python coder solver * update config * fix bad * add log * remove unused code * commit with task thought * move kag model to common * add default chat llm * fix * use static planner * support chunk graph node * add args * support naive rag * llm client support tool calls * add default async * add openai * fix result * fix markdown reader * fix thinker * update asyncio interface * feat(solver): add mcp support (#444) * upload mcp client related code * 1. complete an end-to-end mcp client call flow, from pipeline to planner and executor 2. allow multiple mcp_server entries in the json, invoked and selected by the LLM 3. get baidu_map_mcp working * 1. schema * bugfix: remove redundant code --------- Co-authored-by: wanxingyu.wxy <wanxingyu.wxy@antgroup.com> * fix affairqa after solver refactor * fix affairqa after solver refactor * fix readme * add params * update version * update mcp executor * update mcp executor * solver add mcp executor * add missing file * add mpc executor * add executor * x * update * fix requirement * fix main llm config * fix solver * bugfix: fix invoke function call logic * chg eva * update example * add kag layer * add step task * support dot refresh * support dot refresh * support dot refresh * support dot refresh * add retrieved num * add retrieved num * add pipelineconf * update ppr * update musique prompts * update * add to_dict for BuilderComponentData * async build * add deduce prompt * add deduce prompt * add deduce prompt * fix reader * add deduce prompt * add page thinker report * modify prmpt * add step status * add self cognition * add self cognition * add memory graph storage * add now time * update memory config * add now time * chg graph loader * add prqa dataset and code * bugfix: fix prqa call logic * optimize: improve code logic, normalize generated answers * add retry py code * update memory graph * update memory graph * fix * fix ner * add with_out_refer generator prompt * fix * close ckpt * fix query * fix query * update version * add llm checker * add llm checker * 1. upload evalutor.py and change the gold_answer.json format 2. improve code logic 3. update README.md * update exp * update exp * rerank support * add static rewrite query * recall more chunks * fix graph load * add static rewrite query * fix bugs * add finish check * add finish check * add finish check * add finish check * 1. upload evalutor.py results 2. improve code logic and the readme * add lf retry * add memory graph api * fix reader api * add ner * add metrics * fix bug * remove ner * add reraise fo retry * add edge prop to memory graph * add memory graph * 1. correct evaluation dataset results 2. improve evaluator.py code 3. remove questions that have an answer in gold_answer but no result * delete evaluation result files * fix knext host addr * async eva * add lf prompt * add lf prompt * add config * add retry * add unknown check * add rc result * add rc result * add rc result * add rc result * adapt code logic to the kag pipeline format, tests pass * bugfix: remove redundant code * fix report prompt * bugfix: trigger retry mechanism * bugfix: wrong Chinese punctuation * fix rethinker prompt * update version to 0.6.2b78 * update version * 1. update evaluator.py to compute accuracy with an LLM, matching the latest call logic 2. update prompt so that unanswered results are retested * update affairqa for evaluate * update affairqa for evaluate * bugfix: fix dataset * bugfix: fix dataset * bugfix: fix dataset * fix name conflict * bugfix: remove erroneous questions * bugfix: wrong file name caused evaluator failure * update for affairqa eval * bugfix: update code to keep evaluate logic consistent * x * update for affairqa readme * remove temp eval scripts * bugfix for math deduce * merge 0.6.2_dev * merge 0.6.2_dev * fix * update client addr * updated version * update for affairqa eval * evaUtils supports Chinese * fix affairqa eval: * remove unused example * update kag config * fix default value * update readme * fix init * update comments and add descriptions for some classes * update example config * Tc 0.7.0 (#459) * submit affairQA code * fix affairqa eval --------- Co-authored-by: zhengke.gzk <zhengke.gzk@antgroup.com> * fix all examples * reformat --------- Co-authored-by: peilong <peilong.zpl@antgroup.com> Co-authored-by: 锦呈 <zhangxinhong.zxh@antgroup.com> Co-authored-by: wanxingyu.wxy <wanxingyu.wxy@antgroup.com> Co-authored-by: zhengke.gzk <zhengke.gzk@antgroup.com> * update chunk metadata * update chunk metadata * add debug reporter * update table text * add server * fix math executor * update api-key for openai vec * update * fix naive rag bug * format code * fix --------- Co-authored-by: zhuzhongshu123 <152354526+zhuzhongshu123@users.noreply.github.com> Co-authored-by: 锦呈 <zhangxinhong.zxh@antgroup.com> Co-authored-by: wanxingyu.wxy <wanxingyu.wxy@antgroup.com> Co-authored-by: zhengke.gzk <zhengke.gzk@antgroup.com>
parent 9b2d894295
commit e1012d39e4
@@ -1 +1 @@
-0.8.0
+0.8.0
@@ -463,9 +463,17 @@ def resolve_instance(


 def extract_tag_content(text):
-    # Match the content between <tag> and </tag>; any tag name is supported
-    matches = re.findall(r"<([^>]+)>(.*?)</\1>", text, flags=re.DOTALL)
-    return [(tag, content.strip()) for tag, content in matches]
+    pattern = r"<(\w+)\b[^>]*>(.*?)</\1>|<(\w+)\b[^>]*>([^<]*)|([^<]+)"
+    results = []
+    for match in re.finditer(pattern, text, re.DOTALL):
+        tag1, content1, tag2, content2, raw_text = match.groups()
+        if tag1:
+            results.append((tag1, content1))  # keep the original content (including whitespace)
+        elif tag2:
+            results.append((tag2, content2))  # keep the original content (including whitespace)
+        elif raw_text:
+            results.append(("", raw_text))  # keep the original whitespace
+    return results


 def extract_specific_tag_content(text, tag):
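A minimal, self-contained sketch of how the new tokenizer behaves, assuming only the regular expression shown in the hunk above. The sample string is made up for illustration. Compared with the old `re.findall` version, it keeps untagged text (reported with an empty tag name), accepts an opening tag that is never closed, and no longer strips whitespace.

```python
import re


def extract_tag_content(text):
    # Alternative 1: a properly closed <tag>...</tag> pair.
    # Alternative 2: an opening tag that is never closed.
    # Alternative 3: plain text outside any tag.
    pattern = r"<(\w+)\b[^>]*>(.*?)</\1>|<(\w+)\b[^>]*>([^<]*)|([^<]+)"
    results = []
    for match in re.finditer(pattern, text, re.DOTALL):
        tag1, content1, tag2, content2, raw_text = match.groups()
        if tag1:
            results.append((tag1, content1))
        elif tag2:
            results.append((tag2, content2))
        elif raw_text:
            results.append(("", raw_text))
    return results


# Hypothetical input: plain text, a closed tag, more text, an unclosed tag.
sample = "intro <think>plan A</think> tail <answer>42"
parts = extract_tag_content(sample)
for tag, content in parts:
    print(repr(tag), repr(content))
```

Because untagged text is now returned as well, callers that only want tagged segments need to skip entries whose tag name is empty.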
@@ -131,9 +131,11 @@ class PyBasedMathExecutor(ExecutorABC):
         )

         parent_results = format_task_dep_context(task.parents)
-        parent_results = "\n".join(parent_results)
+        coder_content = context.kwargs.get("planner_thought", "") + "\n\n".join(
+            parent_results
+        )

-        parent_results += "\n\n" + contents
+        coder_content += "\n\n" + contents
         tries = self.tries
         error = None

@@ -141,7 +143,7 @@ class PyBasedMathExecutor(ExecutorABC):
             tries -= 1
             rst, error, code = self.run_once(
                 math_query,
-                parent_results,
+                coder_content,
                 error,
                 segment_name=tag_id,
                 tag_name=f"{task_query}_code_generator",
@@ -42,6 +42,15 @@ from kag.solver.utils import init_prompt_with_fallback
 logger = logging.getLogger()


+def _wrapped_invoke(retriever, task, context, segment_name, kwargs):
+    start_time = time.time()
+    output = retriever.invoke(
+        task, context=context, segment_name=segment_name, **kwargs
+    )
+    elapsed_time = time.time() - start_time
+    return output, elapsed_time
+
+
 @ExecutorABC.register("kag_hybrid_retrieval_executor")
 class KAGHybridRetrievalExecutor(ExecutorABC):
     def __init__(
@@ -76,6 +85,7 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
         self.context_select_prompt = context_select_prompt or PromptABC.from_config(
             {"type": "context_select_prompt"}
         )
+        self.with_llm_select = kwargs.get("with_llm_select", True)

     @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1))
     def context_select_call(self, variables):
@@ -152,22 +162,30 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
                         "FINISH",
                         component_name=retriever.name,
                     )
-
+                    # Record start time before submitting the task
+                    start_time = time.time()
                     # Prepare function and submit to thread pool
                     func = partial(
-                        retriever.invoke,
+                        _wrapped_invoke,
+                        retriever,
                         task,
-                        context=context,
-                        segment_name=tag_id,
-                        **kwargs,
+                        context,
+                        tag_id,
+                        kwargs.copy(),
                     )
                     future = executor.submit(func)
+                    # Save future, retriever, and start_time together
                     futures.append((future, retriever))

                 # Collect results from each future
                 for future, retriever in futures:
                     try:
-                        output = future.result()  # Wait for result
+                        output, elapsed_time = future.result()  # Wait for result
+
+                        # Log the elapsed time for this retriever
+                        logger.info(
+                            f"Retriever {retriever.name} executed in {elapsed_time:.2f} seconds"
+                        )
                         outputs.append(output)

                         # Log data report after successful execution
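The timing change above can be tried in isolation with the following sketch. `_wrapped_invoke` follows the helper added by this commit; `FakeRetriever` and `run_all` are hypothetical stand-ins invented for this example and are not part of KAG. The point it illustrates: measuring inside the submitted function records the time each retriever actually spent running in its worker thread, rather than the time since submission.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial


class FakeRetriever:
    """Hypothetical retriever used only for this illustration."""

    def __init__(self, name, delay):
        self.name = name
        self.delay = delay

    def invoke(self, task, context=None, segment_name=None, **kwargs):
        time.sleep(self.delay)
        return f"{self.name}:{task}"


def _wrapped_invoke(retriever, task, context, segment_name, kwargs):
    # Time the call inside the worker thread and return it with the output.
    start_time = time.time()
    output = retriever.invoke(
        task, context=context, segment_name=segment_name, **kwargs
    )
    elapsed_time = time.time() - start_time
    return output, elapsed_time


def run_all(retrievers, task):
    futures = []
    with ThreadPoolExecutor() as executor:
        for retriever in retrievers:
            func = partial(_wrapped_invoke, retriever, task, None, "seg", {})
            futures.append((executor.submit(func), retriever))
        # Each future resolves to an (output, elapsed_time) pair.
        return [(r.name, *f.result()) for f, r in futures]


results = run_all([FakeRetriever("a", 0.01), FakeRetriever("b", 0.0)], "q")
for name, output, elapsed in results:
    print(f"Retriever {name} executed in {elapsed:.2f} seconds -> {output}")
```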
@@ -241,13 +259,18 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
         selected_rel = list(set(selected_rel))
         formatted_docs = [str(rel) for rel in selected_rel]
         if retrieved_data.chunks:
-            try:
-                selected_chunks = self.context_select(task_query, retrieved_data.chunks)
-            except Exception as e:
-                logger.warning(
-                    f"select context failed {e}, we use default top 10 to summary",
-                    exc_info=True,
-                )
+            if self.with_llm_select:
+                try:
+                    selected_chunks = self.context_select(
+                        task_query, retrieved_data.chunks
+                    )
+                except Exception as e:
+                    logger.warning(
+                        f"select context failed {e}, we use default top 10 to summary",
+                        exc_info=True,
+                    )
+                    selected_chunks = retrieved_data.chunks[:10]
+            else:
                 selected_chunks = retrieved_data.chunks[:10]
             for doc in selected_chunks:
                 formatted_docs.append(f"{doc.content}")
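The selection logic in this hunk reduces to "use the LLM selector when enabled, otherwise (or on failure) take the first ten chunks". A minimal sketch of that pattern follows; `select_chunks`, `selector`, and `top_k` are names invented for this example, and the real method works on `retrieved_data.chunks` and `self.context_select`.

```python
import logging

logger = logging.getLogger(__name__)


def select_chunks(chunks, selector, with_llm_select=True, top_k=10):
    """Return selector(chunks) when enabled and successful, else chunks[:top_k]."""
    if with_llm_select:
        try:
            return selector(chunks)
        except Exception as e:
            # Same fallback as the disabled path: keep the first top_k chunks.
            logger.warning("select context failed %s, using top %d", e, top_k)
    return chunks[:top_k]


def failing_selector(chunks):
    raise RuntimeError("LLM unavailable")


chunks = [f"chunk-{i}" for i in range(25)]
print(len(select_chunks(chunks, failing_selector)))
print(len(select_chunks(chunks, lambda c: c[:3])))
print(len(select_chunks(chunks, failing_selector, with_llm_select=False)))
```

Setting `with_llm_select=False` skips the selector call entirely, which avoids one LLM round trip per retrieval at the cost of a cruder cut.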
@@ -280,68 +303,81 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
         task_query = task.arguments["query"]

         tag_id = f"{task_query}_begin_task"
-        self.report_content(reporter, "thinker", tag_id, "", "FINISH", step=task.name)
+        self.report_content(reporter, "thinker", tag_id, "", "INIT", step=task.name)
         try:
-            retrieved_data = self.do_main(task_query, tag_id, task, context, **kwargs)
-        except Exception as e:
-            logger.warning(f"kag hybrid retrieval failed! {e}", exc_info=True)
-            retrieved_data = RetrieverOutput(
-                retriever_method=self.schema().get("name", ""), err_msg=str(e)
-            )
-
-        self.report_content(
-            reporter,
-            "reference",
-            f"{task_query}_kag_retriever_result",
-            retrieved_data,
-            "FINISH",
-        )
-
-        retrieved_data.task = task
-        logical_node = task.arguments.get("logic_form_node", None)
-        if (
-            logical_node
-            and isinstance(logical_node, GetSPONode)
-            and retrieved_data.summary
-        ):
-            if isinstance(retrieved_data.summary, str):
-                target_answer = retrieved_data.summary.split("Answer:")[-1].strip()
-                s_entities = context.variables_graph.get_entity_by_alias(
-                    logical_node.s.alias_name
+            try:
+                retrieved_data = self.do_main(
+                    task_query, tag_id, task, context, **kwargs
                 )
-                if (
-                    not s_entities
-                    and not logical_node.s.get_mention_name()
-                    and isinstance(logical_node.s, SPOEntity)
-                ):
-                    logical_node.s.entity_name = target_answer
-                    context.kwargs[logical_node.s.alias_name] = logical_node.s
-                o_entities = context.variables_graph.get_entity_by_alias(
-                    logical_node.o.alias_name
+            except Exception as e:
+                logger.warning(f"kag hybrid retrieval failed! {e}", exc_info=True)
+                retrieved_data = RetrieverOutput(
+                    retriever_method=self.schema().get("name", ""), err_msg=str(e)
                 )
-                if (
-                    not o_entities
-                    and not logical_node.o.get_mention_name()
-                    and isinstance(logical_node.o, SPOEntity)
-                ):
-                    logical_node.o.entity_name = target_answer
-                    context.kwargs[logical_node.o.alias_name] = logical_node.o

-            context.variables_graph.add_answered_alias(
-                logical_node.s.alias_name.alias_name, retrieved_data.summary
-            )
-            context.variables_graph.add_answered_alias(
-                logical_node.p.alias_name.alias_name, retrieved_data.summary
-            )
-            context.variables_graph.add_answered_alias(
-                logical_node.o.alias_name.alias_name, retrieved_data.summary
+            self.report_content(
+                reporter,
+                "reference",
+                f"{task_query}_kag_retriever_result",
+                retrieved_data,
+                "FINISH",
             )

-        task.update_result(retrieved_data)
-        logger.debug(
-            f"kag hybrid retrieval {task_query} cost={time.time() - start_time}"
-        )
-        return retrieved_data
+            retrieved_data.task = task
+            logical_node = task.arguments.get("logic_form_node", None)
+            if (
+                logical_node
+                and isinstance(logical_node, GetSPONode)
+                and retrieved_data.summary
+            ):
+                if isinstance(retrieved_data.summary, str):
+                    target_answer = retrieved_data.summary.split("Answer:")[-1].strip()
+                    s_entities = context.variables_graph.get_entity_by_alias(
+                        logical_node.s.alias_name
+                    )
+                    if (
+                        not s_entities
+                        and not logical_node.s.get_mention_name()
+                        and isinstance(logical_node.s, SPOEntity)
+                    ):
+                        logical_node.s.entity_name = target_answer
+                        context.kwargs[logical_node.s.alias_name] = logical_node.s
+                    o_entities = context.variables_graph.get_entity_by_alias(
+                        logical_node.o.alias_name
+                    )
+                    if (
+                        not o_entities
+                        and not logical_node.o.get_mention_name()
+                        and isinstance(logical_node.o, SPOEntity)
+                    ):
+                        logical_node.o.entity_name = target_answer
+                        context.kwargs[logical_node.o.alias_name] = logical_node.o
+
+                context.variables_graph.add_answered_alias(
+                    logical_node.s.alias_name.alias_name, retrieved_data.summary
+                )
+                context.variables_graph.add_answered_alias(
+                    logical_node.p.alias_name.alias_name, retrieved_data.summary
+                )
+                context.variables_graph.add_answered_alias(
+                    logical_node.o.alias_name.alias_name, retrieved_data.summary
+                )
+
+            task.update_result(retrieved_data)
+            logger.debug(
+                f"kag hybrid retrieval {task_query} cost={time.time() - start_time}"
+            )
+            return retrieved_data
+        finally:
+            self.report_content(
+                reporter,
+                "thinker",
+                tag_id,
+                "",
+                "FINISH",
+                step=task.name,
+                overwrite=False,
+            )

     def schema(self) -> dict:
         """Function schema definition for OpenAI Function Calling
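The control flow introduced by this hunk (report `INIT`, convert retrieval failures into an error result, and always report `FINISH` from a `finally` block) can be sketched without any KAG classes. Everything below is a simplified stand-in: `events`, `invoke`, and the dict-based error result are assumptions made for illustration, not the executor's real API.

```python
events = []


def invoke(do_main):
    """Run do_main, recording INIT first and FINISH last, even on errors."""
    events.append("INIT")
    try:
        try:
            result = do_main()
        except Exception as e:
            # Mirror the hunk: a failure becomes an error-carrying result.
            result = {"err_msg": str(e)}
        events.append("reference")
        return result
    finally:
        # Runs after the return value is computed, so FINISH is always last.
        events.append("FINISH")


def failing_main():
    raise ValueError("retrieval backend down")


result = invoke(failing_main)
print(result, events)
```

Moving the final status report into `finally` means the thinker step is closed even if the post-processing after retrieval raises.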
@@ -403,7 +439,7 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
                 node_type=chunk.properties.get("__labels__"),
             )
             entity_prop = dict(chunk.properties) if chunk.properties else {}
-            entity_prop["content"] = chunk.content
+            entity_prop["content"] = f"{chunk.content[:10]}..."
             entity_prop["score"] = chunk.score
             entity.prop = Prop.from_dict(entity_prop, "Chunk", None)
             chunk_graph.append(entity)
@@ -140,8 +140,6 @@ def get_pipeline_conf(use_pipeline_name, config):
             raise RuntimeError("mcpServers not found in config.")
         default_solver_pipeline["executors"] = mcp_executors

-    # update KAG_CONFIG
-    KAG_CONFIG.update_conf(default_pipeline_conf)
     return default_solver_pipeline


@@ -167,8 +165,11 @@ async def do_qa_pipeline(
                     f"Knowledge base with id {kb_project_id} not found in qa_config['kb']"
                 )
                 continue
-
-            for index_name in matched_kb.get("index_list", []):
+            index_list = matched_kb.get("index_list", [])
+            if use_pipeline in ["default_pipeline"]:
+                # we only use chunk index
+                index_list = ["chunk_index"]
+            for index_name in index_list:
                 index_manager = KAGIndexManager.from_config(
                     {
                         "type": index_name,
@@ -339,7 +340,7 @@ class SolverMain:
     def invoke(
         self,
         project_id: int,
-        task_id: int,
+        task_id,
         query: str,
         session_id: str = "0",
         is_report=True,
@@ -3,20 +3,17 @@ pipeline_name: default_pipeline

 #------------kag-solver configuration start----------------#

-
-chunk_retrieved_executor: &chunk_retrieved_executor_conf
-  type: chunk_retrieved_executor
-  top_k: 10
-  retriever:
-    type: vector_chunk_retriever
-    score_threshold: 0.65
-    vectorize_model: "{vectorize_model}"
-
+kag_retriever_executor: &kag_retriever_executor_conf
+  type: kag_hybrid_retrieval_executor
+  retrievers: "{retrievers}"
+  merger:
+    type: kag_merger
+  enable_summary:  false

 solver_pipeline:
   type: naive_rag_pipeline
   executors:
-    - *chunk_retrieved_executor_conf
+    - *kag_retriever_executor_conf
   generator:
     type: llm_index_generator
     llm_client: "{chat_llm}"
@@ -186,6 +186,7 @@ class KAGModelPlanner(PlannerABC):
             .replace("</answer>", "")
             .strip()
         )
+        context.kwargs["planner_thought"] = logic_form_response

         sub_queries, logic_forms = parse_logic_form_with_str(logic_form_str)
         logic_forms = self.logic_node_parser.parse_logic_form_set(
@@ -11,42 +11,49 @@ logger = logging.getLogger(__name__)
 class ExpressionBuildr(PromptABC):
     template_zh = (
         f"今天是{get_now(language='zh')}。"
-        + """\n# instruction
+        + """
+# instruction
 根据给出的问题和数据,编写python代码,输出问题结果。
-为了便于理解,输出从context中提取的数据,输出中间计算过程和结果。
-注意严格根据输入内容进行编写代码,不允许进行假设
-例如伤残等级如果context中未提及,则认为没有被认定为残疾
-如果无法回答问题,直接返回:I don't know.
+从context中提取的数据必须显式赋值,所有计算步骤必须用代码实现,不得隐含推断。
+必须输出中间计算过程和结果,格式为print语句。
+如果context未提供必要数据或无法计算,直接打印"I don't know."

 # output format
-直接输出python代码,python版本为3.10,不要包含任何其他信息
+严格输出以下结构的python代码(版本3.10):
+1. 数据提取部分:代码中涉及输入的数值需要从context及question中提取,不允许进行假设
+2. 计算过程:分步实现所有数学运算,每个步骤对应独立变量
+3. 输出:每个中间变量和最终结果必须用print语句输出

 # examples
 ## 例子1
 ### input
 #### question
-47000元按照万分之1.5一共612天,计算利息,一共多少钱?
+4百万元按照日利率万分之1.5,一共612天,计算利息,一共多少钱?
+#### context
+日利率万分之1.5
 ### output
 ```python
-# 初始本金
-principal = 47000
+# 初始本金(单位:百万)
+principal = 4  # 单位:百万

-# 利率(万分之1.5)
-rate = 1.5 / 10000
+# 日利率计算(万分之1.5)
+daily_rate = 1.5 / 10000

-# 天数
+# 计算周期
 days = 612

-# 计算年利率
-annual_rate = rate * 365
+# 单日利息计算
+daily_interest = principal * daily_rate

-# 计算利息
-interest = principal * (annual_rate / 365) * days
+# 累计利息计算
+total_interest = daily_interest * days

-# 输出总金额(本金+利息)
-total_amount = principal + interest
+# 总金额计算
+total_amount = principal + total_interest

-print(f"总金额:{total_amount:.2f}元")
+print(f"单日利息:{daily_interest:.2f}百万")
+print(f"累计利息:{total_interest:.2f}百万")
+print(f"总金额:{total_amount:.2f}百万")
 ```

 ## 例子2
@@ -70,13 +77,26 @@ revenue_2020 = revenue_2019 * (1 + growth_rate)
 print(f"2020年的预计收入为: {revenue_2020:.2f}万")
 ```

+## 例子3
+### input
+#### question
+47000元按照612天计算利息,本息一共多少钱?
+#### content
+
+### output
+```python
+# 未给出利率,无法计算
+print("未给出利率,无法计算")
+```
 # input
 ## question
 $question
 ## context
 $context
 ## error
-        $error"""
+$error
+## output
+"""
     )
     template_en = (
         f"Today is {get_now(language='en')}。\n"
@@ -1,5 +1,6 @@
 import logging
 import re
+import time

 from kag.common.conf import KAG_PROJECT_CONF
 from kag.common.parser.logic_node_parser import extract_steps_and_actions
@@ -72,14 +73,16 @@ def process_tag_template(text):
         }
         clean_text = ""
         for tag_info in all_tags:
+            content = tag_info[1]
             if tag_info[0] in xml_tag_template:
-                content = tag_info[1]
                 if "search" == tag_info[0]:
                     content = process_planning(content)
                 clean_text += xml_tag_template[tag_info[0]][
                     KAG_PROJECT_CONF.language
                 ].format_map(SafeDict({"content": content}))
-        return remove_xml_tags(clean_text)
+            else:
+                clean_text += content
+        text = remove_xml_tags(clean_text)
     return text


@@ -159,12 +159,12 @@ def render_jinja2_template(template_str, context):
     """
     try:
         template = Template(template_str, undefined=SilentUndefined)
-        return template.render(**context).strip()
+        return template.render(**context)
     except Exception as e:
         logging.error(
             f"Jinja2 rendering failed: {e}, Original template: {template_str}"
         )
-        return template_str.strip()  # Fallback to raw template string on failure
+        return template_str  # Fallback to raw template string on failure


 @ReporterABC.register("open_spg_reporter")
@@ -264,12 +264,12 @@ Rerank the documents and take the top {{ chunk_num }}.
         }
         self.tag_mapping = {
             "Graph Show": {
-                "en": "{content}",
-                "zh": "{content}",
+                "en": "{{ content }}",
+                "zh": "{{ content }}",
             },
             "Rewrite query": {
-                "en": "Rethinking question using LLM: {content}",
-                "zh": "根据依赖问题重写子问题: {content}",
+                "en": "Rethinking question using LLM: {{ content }}",
+                "zh": "根据依赖问题重写子问题: {{ content }}",
             },
             "language_setting": {
                 "en": "",
@@ -277,125 +277,153 @@ Rerank the documents and take the top {{ chunk_num }}.
             },
             "Iterative planning": {
                 "en": """
-<step status="{status}" title="Global planning">
+<step status="{{status}}" title="Global planning">

-{content}
+{{ content }}

-</step>""",
+{% if status == 'success' %}
+</step>
+{% endif %}""",
                 "zh": """
-<step status="{status}" title="思考当前步骤">
+<step status="{{status}}" title="思考当前步骤">

-{content}
+{{ content }}

-</step>""",
+{% if status == 'success' %}
+</step>
+{% endif %}""",
             },
             "Static planning": {
                 "en": """
-<step status="{status}" title="Global planning">
+<step status="{{status}}" title="Global planning">

-{content}
+{{ content }}

-</step>""",
+{% if status == 'success' %}
+</step>
+{% endif %}""",
                 "zh": """
-<step status="{status}" title="思考全局步骤">
+<step status="{{status}}" title="思考全局步骤">

-{content}
+{{ content }}

-</step>""",
+{% if status == 'success' %}
+</step>
+{% endif %}""",
             },
             "begin_sub_kag_retriever": {
-                "en": "Starting {component_name}: {content} {desc}",
-                "zh": "执行{component_name}: {content} {desc}",
+                "en": "Starting {{component_name}}: {{content}} {{desc}}",
+                "zh": "执行{{component_name}}: {{content}} {{desc}}",
             },
             "end_sub_kag_retriever": {
-                "en": " {content}",
-                "zh": " {content}",
+                "en": " {{ content }}",
+                "zh": " {{ content }}",
             },
             "rc_retriever_rewrite": {
                 "en": """
-<step status="{status}" title="Rewriting chunk retriever query">
+<step status="{{status}}" title="Rewriting chunk retriever query">

-Rewritten question:\n{content}
+Rewritten question:
+{{ content }}

-</step>""",
+{% if status == 'success' %}
+</step>
+{% endif %}""",
                 "zh": """
-<step status="{status}" title="正在根据依赖问题重写检索子问题">
+<step status="{{status}}" title="正在根据依赖问题重写检索子问题">

-重写问题为:\n\n{content}
+重写问题为:
+{{ content }}

-</step>""",
+{% if status == 'success' %}
+</step>
+{% endif %}""",
            },
 | 
			
		||||
            "rc_retriever_summary": {
 | 
			
		||||
                "en": "Summarizing retrieved documents,{content}",
 | 
			
		||||
                "zh": "对文档进行总结,{content}",
 | 
			
		||||
                "en": "Summarizing retrieved documents,{{ content }}",
 | 
			
		||||
                "zh": "对文档进行总结,{{ content }}",
 | 
			
		||||
            },
 | 
			
		||||
            "kg_retriever_summary": {
 | 
			
		||||
                "en": "Summarizing retrieved graph,{content}",
 | 
			
		||||
                "zh": "对召回的知识进行总结,{content}",
 | 
			
		||||
                "en": "Summarizing retrieved graph,{{ content }}",
 | 
			
		||||
                "zh": "对召回的知识进行总结,{{ content }}",
 | 
			
		||||
            },
 | 
			
		||||
            "retriever_summary": {
 | 
			
		||||
                "en": "Summarizing retrieved documents,{content}",
 | 
			
		||||
                "zh": "对文档进行总结,{content}",
 | 
			
		||||
                "en": "Summarizing retrieved documents,{{ content }}",
 | 
			
		||||
                "zh": "对文档进行总结,{{ content }}",
 | 
			
		||||
            },
 | 
			
		||||
            "begin_summary": {
 | 
			
		||||
                "en": "Summarizing retrieved information, {content}",
 | 
			
		||||
                "zh": "对检索的信息进行总结, {content}",
 | 
			
		||||
                "en": "Summarizing retrieved information, {{ content }}",
 | 
			
		||||
                "zh": "对检索的信息进行总结, {{ content }}",
 | 
			
		||||
            },
 | 
			
		||||
            "begin_task": {
 | 
			
		||||
                "en": """
 | 
			
		||||
<step status="{status}" title="Starting Task {step}">
 | 
			
		||||
<step status="{{status}}" title="Starting Task {{step}}">
 | 
			
		||||
 | 
			
		||||
{content}
 | 
			
		||||
{{ content }}
 | 
			
		||||
 | 
			
		||||
</step>""",
 | 
			
		||||
{% if status == 'success' %}
 | 
			
		||||
</step>
 | 
			
		||||
{% endif %}""",
 | 
			
		||||
                "zh": """
 | 
			
		||||
<step status="{status}" title="执行 {step}">
 | 
			
		||||
<step status="{{status}}" title="执行 {{step}}">
 | 
			
		||||
 | 
			
		||||
{content}
 | 
			
		||||
{{ content }}
 | 
			
		||||
 | 
			
		||||
</step>""",
 | 
			
		||||
{% if status == 'success' %}
 | 
			
		||||
</step>
 | 
			
		||||
{% endif %}""",
 | 
			
		||||
            },
 | 
			
		||||
            "logic_node": {
 | 
			
		||||
                "en": """Translate query to logic form expression
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
```json
 | 
			
		||||
{content}
 | 
			
		||||
{{ content }}
 | 
			
		||||
```""",
 | 
			
		||||
                "zh": """将query转换成逻辑形式表达
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
```json
 | 
			
		||||
{content}
 | 
			
		||||
{{ content }}
 | 
			
		||||
```""",
 | 
			
		||||
            },
 | 
			
		||||
            "kag_retriever_result": {
 | 
			
		||||
                "en": "Retrieved documents\n\n{content}",
 | 
			
		||||
                "zh": "检索到的文档\n\n{content}",
 | 
			
		||||
                "en": """Retrieved documents
 | 
			
		||||
{{ content }}""",
 | 
			
		||||
                "zh": """检索到的文档
 | 
			
		||||
{{ content }}""",
 | 
			
		||||
            },
 | 
			
		||||
            "failed_kag_retriever": {
 | 
			
		||||
                "en": """KAG retriever failed
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
```json
 | 
			
		||||
{content}
 | 
			
		||||
{{ content }}
 | 
			
		||||
```
 | 
			
		||||
""",
 | 
			
		||||
                "zh": """KAG检索失败
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
```json
 | 
			
		||||
{content}
 | 
			
		||||
{{ content }}
 | 
			
		||||
```
 | 
			
		||||
                """,
 | 
			
		||||
            },
 | 
			
		||||
            "end_math_executor": {
 | 
			
		||||
                "en": "Math executor completed\n\n{content}",
 | 
			
		||||
                "zh": "计算结束\n\n{content}",
 | 
			
		||||
                "en": """Math executor completed
 | 
			
		||||
{{ content }}""",
 | 
			
		||||
                "zh": """计算结束
 | 
			
		||||
{{ content }}""",
 | 
			
		||||
            },
 | 
			
		||||
            "code_generator": {
 | 
			
		||||
                "en": "Generating code\n \n{content}\n",
 | 
			
		||||
                "zh": "正在生成代码\n \n{content}\n",
 | 
			
		||||
                "en": """Generating code
 | 
			
		||||
{{ content }}
 | 
			
		||||
 | 
			
		||||
""",
 | 
			
		||||
                "zh": """正在生成代码
 | 
			
		||||
{{ content }}
 | 
			
		||||
 | 
			
		||||
""",
 | 
			
		||||
            },
 | 
			
		||||
        }
 | 
			
		||||
        task_id = kwargs.get(KAGConstants.KAG_QA_TASK_CONFIG_KEY, None)
 | 
			
		||||
@ -425,7 +453,11 @@ Rewritten question:\n{content}
 | 
			
		||||
        if tpl:
 | 
			
		||||
            format_params = {"content": datas}
 | 
			
		||||
            format_params.update(content_params)
 | 
			
		||||
            datas = tpl.format_map(SafeDict(format_params))
 | 
			
		||||
            if "{" in tpl or "%}" in tpl:
 | 
			
		||||
                rendered = render_jinja2_template(tpl, format_params)
 | 
			
		||||
            else:
 | 
			
		||||
                rendered = tpl.format_map(SafeDict(format_params))
 | 
			
		||||
            datas = rendered
 | 
			
		||||
        elif str(datas).strip() != "":
 | 
			
		||||
            output = str(datas).strip()
 | 
			
		||||
            if output != "":
 | 
			
		||||
@ -516,7 +548,7 @@ Rewritten question:\n{content}
 | 
			
		||||
            if self.last_report.to_dict() == request.to_dict():
 | 
			
		||||
                return
 | 
			
		||||
            logger.info(
 | 
			
		||||
                f"do_report: {content.answer} think={content.think} status={status_enum} ret={ret}"
 | 
			
		||||
                f"do_report: think={content.think} {content.answer} status={status_enum} ret={ret}"
 | 
			
		||||
            )
 | 
			
		||||
            self.last_report = request
 | 
			
		||||
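The fallback branch in the hunk above renders a template with `tpl.format_map(SafeDict(format_params))`. `SafeDict` is not shown in this diff, so the sketch below assumes the common pattern of a `dict` subclass whose `__missing__` leaves unknown placeholders in place. The class body, the `render_legacy` helper, and the sample values are illustrative assumptions, not the repository's code.

```python
class SafeDict(dict):
    """Assumed behaviour: keep unknown placeholders instead of raising KeyError."""

    def __missing__(self, key):
        return "{" + key + "}"


def render_legacy(tpl, content, **content_params):
    """Mimic the format_map branch: 'content' plus any extra parameters."""
    format_params = {"content": content}
    format_params.update(content_params)
    return tpl.format_map(SafeDict(format_params))


tpl = "Starting {component_name}: {content} {desc}"
# 'desc' is not supplied, so its placeholder survives unchanged.
print(render_legacy(tpl, "sub-question 1", component_name="kg_fr"))
```

With a plain `dict`, the missing `desc` key would raise `KeyError`; the assumed `SafeDict` trades that error for a visible, unrendered placeholder.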
 | 
			
		||||
 | 
			
		||||
							
								
								
									
0  kag/solver/server/__init__.py  Normal file
168  kag/solver/server/asyn_task_manager.py  Normal file
@ -0,0 +1,168 @@
 | 
			
		||||
import concurrent.futures
 | 
			
		||||
import queue
 | 
			
		||||
import threading
 | 
			
		||||
import time
 | 
			
		||||
import uuid
 | 
			
		||||
import logging
 | 
			
		||||
from cachetools import TTLCache
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AsyncTaskManager:
 | 
			
		||||
    def __init__(self, max_workers=10, ttl=3600):
 | 
			
		||||
        """
 | 
			
		||||
        Initialize async task manager
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            max_workers (int): Maximum number of worker threads
 | 
			
		||||
            ttl (int): Time-to-live for task results in seconds
 | 
			
		||||
        """
 | 
			
		||||
        self.max_workers = max_workers
 | 
			
		||||
        self.task_queue = queue.Queue()
 | 
			
		||||
        self.result_cache = TTLCache(maxsize=1000, ttl=ttl)
 | 
			
		||||
        self.result_cache_lock = threading.Lock()  # Protect cache from race conditions
 | 
			
		||||
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
 | 
			
		||||
        self.workers = [
 | 
			
		||||
            threading.Thread(target=self.worker, daemon=True)
 | 
			
		||||
            for _ in range(max_workers)
 | 
			
		||||
        ]
 | 
			
		||||
        for w in self.workers:
 | 
			
		||||
            w.start()
 | 
			
		||||
 | 
			
		||||
    def worker(self):
 | 
			
		||||
        """Worker thread main loop that processes tasks"""
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
                # Block until the next task or a shutdown sentinel arrives (no timeout is set)
 | 
			
		||||
                task = self.task_queue.get()
 | 
			
		||||
                task_id, func, args, kwargs = task
 | 
			
		||||
                logger.info(f"Processing task {task_id}")
 | 
			
		||||
                # finish flag
 | 
			
		||||
                if task_id is None:
 | 
			
		||||
                    self.task_queue.task_done()
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
                # Update cache with running status
 | 
			
		||||
                with self.result_cache_lock:
 | 
			
		||||
                    self.result_cache[task_id] = {
 | 
			
		||||
                        "task_id": task_id,
 | 
			
		||||
                        "status": "running",
 | 
			
		||||
                        "result": None,
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                # Execute task
 | 
			
		||||
                future = self.executor.submit(func, *args, **kwargs)
 | 
			
		||||
                result = future.result()
 | 
			
		||||
                status = "completed"
 | 
			
		||||
 | 
			
		||||
            except queue.Empty:
 | 
			
		||||
                # Only reachable if get() is given a timeout; a blocking get() never raises queue.Empty
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                # Handle task execution errors
 | 
			
		||||
                result = str(e)
 | 
			
		||||
                status = "failed"
 | 
			
		||||
                logger.error(f"Task {task_id} failed with error: {e}", exc_info=True)
 | 
			
		||||
 | 
			
		||||
            # Store final result in cache
 | 
			
		||||
            try:
 | 
			
		||||
                with self.result_cache_lock:
 | 
			
		||||
                    self.result_cache[task_id] = {
 | 
			
		||||
                        "task_id": task_id,
 | 
			
		||||
                        "status": status,
 | 
			
		||||
                        "result": result,
 | 
			
		||||
                    }
 | 
			
		||||
                logger.info(f"Task {task_id} completed with status: {status}")
 | 
			
		||||
            finally:
 | 
			
		||||
                # Always mark task as done
 | 
			
		||||
                self.task_queue.task_done()
 | 
			
		||||
 | 
			
		||||
    def submit_task(self, func, *args, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        Submit a new task to the queue
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            func: Callable function to execute
 | 
			
		||||
            *args: Positional arguments for the function
 | 
			
		||||
            **kwargs: Keyword arguments for the function
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            str: Unique task ID
 | 
			
		||||
        """
 | 
			
		||||
        task_id = str(uuid.uuid4())
 | 
			
		||||
        self.task_queue.put((task_id, func, args, kwargs))
 | 
			
		||||
        return task_id
 | 
			
		||||
 | 
			
		||||
    def get_task_result(self, task_id):
 | 
			
		||||
        """
 | 
			
		||||
        Get result for a specific task
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            task_id (str): Unique task identifier
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            dict: Task result information or expired status
 | 
			
		||||
        """
 | 
			
		||||
        with self.result_cache_lock:
 | 
			
		||||
            return self.result_cache.get(
 | 
			
		||||
                task_id,
 | 
			
		||||
                {
 | 
			
		||||
                    "task_id": task_id,
 | 
			
		||||
                    "status": "failed",
 | 
			
		||||
                    "result": "Result not found or expired",
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def shutdown(self):
 | 
			
		||||
        """Gracefully shutdown all worker threads and executors"""
 | 
			
		||||
        # Send shutdown signals
 | 
			
		||||
        for _ in range(self.max_workers):
 | 
			
		||||
            self.task_queue.put((None, None, (), {}))
 | 
			
		||||
 | 
			
		||||
        # Wait for queue to empty and workers to terminate
 | 
			
		||||
        self.task_queue.join()
 | 
			
		||||
 | 
			
		||||
        # Shutdown executors
 | 
			
		||||
        self.executor.shutdown(wait=True)
 | 
			
		||||
        for worker in self.workers:
 | 
			
		||||
            worker.join(timeout=5)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Global async task manager instance
 | 
			
		||||
asyn_task = AsyncTaskManager()
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    # Configure logging
 | 
			
		||||
    logging.basicConfig(level=logging.INFO)
 | 
			
		||||
 | 
			
		||||
    # Create task manager instance
 | 
			
		||||
    task_manager = AsyncTaskManager(max_workers=5, ttl=600)
 | 
			
		||||
 | 
			
		||||
    # Example task function
 | 
			
		||||
    def example_task(x, y):
 | 
			
		||||
        time.sleep(1)  # Simulate work
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    # Submit test tasks
 | 
			
		||||
    task_ids = [task_manager.submit_task(example_task, i, i + 1) for i in range(6)]
 | 
			
		||||
 | 
			
		||||
    # Monitor task progress
 | 
			
		||||
    try:
 | 
			
		||||
        while True:
 | 
			
		||||
            time.sleep(1)
 | 
			
		||||
            if all(
 | 
			
		||||
                "completed" in task_manager.get_task_result(tid)["status"]
 | 
			
		||||
                for tid in task_ids
 | 
			
		||||
            ):
 | 
			
		||||
                break
 | 
			
		||||
    except KeyboardInterrupt:
 | 
			
		||||
        logger.info("Shutting down due to user interrupt")
 | 
			
		||||
 | 
			
		||||
    # Print results
 | 
			
		||||
    for task_id in task_ids:
 | 
			
		||||
        print(f"Task {task_id} result: {task_manager.get_task_result(task_id)}")
 | 
			
		||||
 | 
			
		||||
    # Clean up resources
 | 
			
		||||
    task_manager.shutdown()
 | 
			
		||||
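The manager above combines a task queue, sentinel-based shutdown, and a locked result cache. One consequence of its design is that a task that is queued but not yet picked up by a worker has no cache entry, so `get_task_result` reports it as `"failed"`. The sketch below is a simplified, hypothetical variant (`MiniTaskManager` is not part of the repository) that uses a plain `dict` instead of `TTLCache`, runs the callable directly in the worker thread, and records a `"pending"` status at submit time.

```python
import queue
import threading
import uuid


class MiniTaskManager:
    """Hypothetical, simplified stand-in for AsyncTaskManager (no TTL cache)."""

    def __init__(self, max_workers=2):
        self.tasks = queue.Queue()
        self.results = {}
        self.lock = threading.Lock()
        self.workers = [
            threading.Thread(target=self._worker, daemon=True)
            for _ in range(max_workers)
        ]
        for w in self.workers:
            w.start()

    def _worker(self):
        while True:
            task_id, func, args, kwargs = self.tasks.get()
            try:
                if task_id is None:  # sentinel: stop this worker
                    return
                try:
                    result, status = func(*args, **kwargs), "completed"
                except Exception as e:
                    result, status = str(e), "failed"
                with self.lock:
                    self.results[task_id] = {"status": status, "result": result}
            finally:
                # Runs for normal tasks and for the sentinel alike.
                self.tasks.task_done()

    def submit(self, func, *args, **kwargs):
        task_id = str(uuid.uuid4())
        with self.lock:
            # Assumption: expose queued tasks as "pending" rather than missing.
            self.results[task_id] = {"status": "pending", "result": None}
        self.tasks.put((task_id, func, args, kwargs))
        return task_id

    def get(self, task_id):
        with self.lock:
            return self.results.get(
                task_id, {"status": "failed", "result": "not found"}
            )

    def shutdown(self):
        for _ in self.workers:
            self.tasks.put((None, None, (), {}))  # one sentinel per worker
        self.tasks.join()
        for w in self.workers:
            w.join(timeout=5)
```

Because the sentinels are enqueued after all submitted tasks and the queue is FIFO, `shutdown()` returns only once every earlier task has stored its result.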
							
								
								
									
0  kag/solver/server/example/__init__.py  Normal file
63  kag/solver/server/main_server.py  Normal file
@ -0,0 +1,63 @@
 | 
			
		||||
from fastapi import FastAPI
 | 
			
		||||
import uvicorn
 | 
			
		||||
 | 
			
		||||
from kag.solver.main_solver import SolverMain
 | 
			
		||||
from kag.solver.server.asyn_task_manager import AsyncTaskManager
 | 
			
		||||
from kag.solver.server.model.task_req import FeatureRequest, TaskReq
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_main_solver(task: TaskReq):
 | 
			
		||||
    return SolverMain().invoke(
 | 
			
		||||
        project_id=task.project_id,
 | 
			
		||||
        task_id=task.req_id,
 | 
			
		||||
        query=task.req.query,
 | 
			
		||||
        is_report=task.req.report,
 | 
			
		||||
        host_addr=task.req.host_addr,
 | 
			
		||||
        app_id=task.app_id,
 | 
			
		||||
        params=task.config,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class KAGSolverServer:
 | 
			
		||||
    def __init__(self, service_name: str):
 | 
			
		||||
        """
 | 
			
		||||
        Initialize a FastAPI service instance
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            service_name (str): Service name, determines which routing logic to load
 | 
			
		||||
        """
 | 
			
		||||
        self.service_name = service_name
 | 
			
		||||
        self.app = FastAPI(title=f"{service_name} API")
 | 
			
		||||
 | 
			
		||||
        # Bind routes according to service name
 | 
			
		||||
        self._setup_routes()
 | 
			
		||||
        self.async_manager = AsyncTaskManager()
 | 
			
		||||
 | 
			
		||||
    def sync_task(self, task: TaskReq):
 | 
			
		||||
        if task.cmd == "submit":
 | 
			
		||||
 | 
			
		||||
            return {
 | 
			
		||||
                "success": True,
 | 
			
		||||
                "status": "init",
 | 
			
		||||
                "result": self.async_manager.submit_task(run_main_solver, task),
 | 
			
		||||
            }
 | 
			
		||||
        elif task.cmd == "query":
 | 
			
		||||
            return self.async_manager.get_task_result(task_id=task.req_id)
 | 
			
		||||
        else:
 | 
			
		||||
            return {
 | 
			
		||||
                "success": False,
 | 
			
		||||
                "status": "failed",
 | 
			
		||||
                "result": f"invalid input cmd {task.cmd}",
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
    def _setup_routes(self):
 | 
			
		||||
        """Register the /process route on the FastAPI app"""
 | 
			
		||||
 | 
			
		||||
        @self.app.post("/process")
 | 
			
		||||
        def process(req: FeatureRequest):
 | 
			
		||||
            return self.sync_task(task=req.features.task_req)
 | 
			
		||||
 | 
			
		||||
    def run(self, host="0.0.0.0", port=8000):
 | 
			
		||||
        """Start the service"""
 | 
			
		||||
        print(f"Starting {self.service_name} service on {host}:{port}")
 | 
			
		||||
        uvicorn.run(self.app, host=host, port=port)
 | 
			
		||||
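In `sync_task`, the `submit` branch returns the manager-generated UUID in `result`, while the `query` branch looks results up by `task.req_id`. A client therefore appears to need to send the returned UUID back as `req_id` when polling. The sketch below illustrates that flow with a hypothetical in-memory `FakeManager` and a `dispatch` helper that mirrors the three branches; neither name exists in the repository, and the fake runs tasks inline rather than on worker threads.

```python
import uuid


class FakeManager:
    """Hypothetical stand-in for AsyncTaskManager that runs tasks inline."""

    def __init__(self):
        self.results = {}

    def submit_task(self, func, *args, **kwargs):
        task_id = str(uuid.uuid4())
        self.results[task_id] = {
            "task_id": task_id,
            "status": "completed",
            "result": func(*args, **kwargs),
        }
        return task_id

    def get_task_result(self, task_id):
        missing = {
            "task_id": task_id,
            "status": "failed",
            "result": "Result not found or expired",
        }
        return self.results.get(task_id, missing)


def dispatch(manager, cmd, req_id="", func=None, *args):
    """Mirror the submit / query / invalid-command branches of sync_task."""
    if cmd == "submit":
        return {
            "success": True,
            "status": "init",
            "result": manager.submit_task(func, *args),
        }
    if cmd == "query":
        return manager.get_task_result(task_id=req_id)
    return {"success": False, "status": "failed", "result": f"invalid input cmd {cmd}"}


manager = FakeManager()
submitted = dispatch(manager, "submit", "", str.upper, "hello")
# Poll with the id returned by submit, not with a client-chosen id.
polled = dispatch(manager, "query", req_id=submitted["result"])
print(polled["status"], polled["result"])
```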
							
								
								
									
0  kag/solver/server/model/__init__.py  Normal file
122  kag/solver/server/model/task_req.py  Normal file
@ -0,0 +1,122 @@
 | 
			
		||||
import json
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from pydantic import BaseModel, model_validator, field_serializer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReqBody(BaseModel):
 | 
			
		||||
    """Request body model containing query parameters"""
 | 
			
		||||
 | 
			
		||||
    query: str = ""
 | 
			
		||||
    report: bool = True
 | 
			
		||||
    host_addr: str = ""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TaskReq(BaseModel):
 | 
			
		||||
    """Task request model with validation logic"""
 | 
			
		||||
 | 
			
		||||
    app_id: str = ""
 | 
			
		||||
    project_id: int = 0
 | 
			
		||||
    req_id: str = ""
 | 
			
		||||
    cmd: str = ""
 | 
			
		||||
    mode: str = ""
 | 
			
		||||
    req: Optional[str] = None
 | 
			
		||||
    config: str = "{}"
 | 
			
		||||
 | 
			
		||||
    @model_validator(mode="after")
 | 
			
		||||
    def parse_req_to_req_body(self):
 | 
			
		||||
        """Parse req string to ReqBody object and process config field"""
 | 
			
		||||
        try:
 | 
			
		||||
            import json
 | 
			
		||||
 | 
			
		||||
            if isinstance(self.req, str):
 | 
			
		||||
                req_body_dict = json.loads(self.req)
 | 
			
		||||
                self.req = ReqBody(**req_body_dict)
 | 
			
		||||
            if isinstance(self.config, str) and self.config:
 | 
			
		||||
                config_dict = json.loads(self.config)
 | 
			
		||||
                self.config = config_dict
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            raise ValueError(f"Failed to parse 'req' or 'config' field: {e}")
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    @field_serializer("req")
 | 
			
		||||
    def serialize_req(self, value: object) -> object:
 | 
			
		||||
        """Serialize ReqBody back to JSON string"""
 | 
			
		||||
        if isinstance(value, ReqBody):
 | 
			
		||||
            return value.model_dump_json()
 | 
			
		||||
        return value  # Return as-is if already a string
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Request model with TaskReq parsing capability
 | 
			
		||||
class Request(BaseModel):
 | 
			
		||||
    """Container model for task request data"""
 | 
			
		||||
 | 
			
		||||
    in_string: str
 | 
			
		||||
    task_req: Optional[TaskReq] = None
 | 
			
		||||
 | 
			
		||||
    @model_validator(mode="after")
 | 
			
		||||
    def parse_in_string_to_task_req(self):
 | 
			
		||||
        """Convert in_string JSON string to TaskReq object"""
 | 
			
		||||
        try:
 | 
			
		||||
            import json
 | 
			
		||||
 | 
			
		||||
            task_req_dict = json.loads(self.in_string)
 | 
			
		||||
            self.task_req = TaskReq(**task_req_dict)
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            raise ValueError(f"Invalid TaskReq JSON string: {e}")
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FeatureRequest(BaseModel):
 | 
			
		||||
    """Top-level request wrapper with features container"""
 | 
			
		||||
 | 
			
		||||
    features: Request
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
 | 
			
		||||
    def feature_request_parsing():
 | 
			
		||||
        """Demonstrate nested model parsing workflow"""
 | 
			
		||||
        # Build innermost ReqBody JSON string
 | 
			
		||||
        req_body = ReqBody(
 | 
			
		||||
            query="阿里巴巴财报中,2024年-截至9月30日止六个月的收入是多少?其中云智能集团收入是多少?占比是多少",
 | 
			
		||||
            report=True,
 | 
			
		||||
            host_addr="https://spg.alipay.com",
 | 
			
		||||
        )
 | 
			
		||||
        req_body_json = json.dumps(req_body.model_dump())
 | 
			
		||||
 | 
			
		||||
        # Build TaskReq dictionary and serialize to string
 | 
			
		||||
        task_req = TaskReq(
 | 
			
		||||
            req_id="9400110",
 | 
			
		||||
            cmd="submit",
 | 
			
		||||
            mode="async",
 | 
			
		||||
            req=req_body_json,
 | 
			
		||||
            app_id="app_id",
 | 
			
		||||
            project_id=4200050,
 | 
			
		||||
            config={"timeout": 10},
 | 
			
		||||
        )
 | 
			
		||||
        task_req_json = json.dumps(task_req.model_dump())
 | 
			
		||||
 | 
			
		||||
        # Construct final FeatureRequest JSON string
 | 
			
		||||
        input_data = {"features": {"in_string": task_req_json}}
 | 
			
		||||
 | 
			
		||||
        # Deserialize to FeatureRequest model
 | 
			
		||||
        feature_request = FeatureRequest(**input_data)
 | 
			
		||||
 | 
			
		||||
        # Validate in_string parsed to TaskReq
 | 
			
		||||
        assert isinstance(feature_request.features.task_req, TaskReq)
 | 
			
		||||
        assert feature_request.features.task_req.req_id == task_req.req_id
 | 
			
		||||
        assert feature_request.features.task_req.cmd == task_req.cmd
 | 
			
		||||
        assert feature_request.features.task_req.mode == task_req.mode
 | 
			
		||||
        assert feature_request.features.task_req.config == {"timeout": 10}
 | 
			
		||||
 | 
			
		||||
        # Validate TaskReq.req parsed to ReqBody
 | 
			
		||||
        req_body_parsed = feature_request.features.task_req.req
 | 
			
		||||
        assert isinstance(req_body_parsed, ReqBody)
 | 
			
		||||
        assert req_body_parsed.query == req_body.query
 | 
			
		||||
        assert req_body_parsed.report is True
 | 
			
		||||
        assert req_body_parsed.host_addr == req_body.host_addr
 | 
			
		||||
 | 
			
		||||
        print("✅ All assertions passed!")
 | 
			
		||||
 | 
			
		||||
    feature_request_parsing()
 | 
			
		||||
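The models above decode a doubly encoded payload: `features.in_string` is a JSON string holding the task, and the task's `req` and `config` fields are themselves JSON strings. The sketch below reproduces only that encoding with the standard `json` module; `build_payload` and `parse_payload` are hypothetical helpers for illustration and skip the validation that the pydantic models perform.

```python
import json


def build_payload(query, req_id, cmd="submit", config=None):
    """Encode a request the way the nested models expect to decode it."""
    req = json.dumps({"query": query, "report": True, "host_addr": ""})
    task = {
        "req_id": req_id,
        "cmd": cmd,
        "mode": "async",
        "req": req,  # inner JSON string
        "config": json.dumps(config or {}),  # inner JSON string
    }
    return {"features": {"in_string": json.dumps(task)}}


def parse_payload(payload):
    """Decode the outer string first, then the inner 'req' and 'config' strings."""
    task = json.loads(payload["features"]["in_string"])
    task["req"] = json.loads(task["req"])
    task["config"] = json.loads(task["config"])
    return task


payload = build_payload("What is KAG?", req_id="demo-1", config={"timeout": 10})
task = parse_payload(payload)
print(task["req"]["query"], task["config"])
```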
							
								
								
									
1  kag/solver/server/requirement.txt  Normal file
@ -0,0 +1 @@
 | 
			
		||||
fastapi
 | 
			
		||||
@ -48,13 +48,16 @@ class Environment:
 | 
			
		||||
    @property
 | 
			
		||||
    def config(self):
 | 
			
		||||
 | 
			
		||||
        closest_config = self._closest_config()
 | 
			
		||||
        if not hasattr(self, "_config_path") or self._config_path != closest_config:
 | 
			
		||||
            self._config_path = closest_config
 | 
			
		||||
            self._config = self.get_config()
 | 
			
		||||
        try:
 | 
			
		||||
            closest_config = self._closest_config()
 | 
			
		||||
            if not hasattr(self, "_config_path") or self._config_path != closest_config:
 | 
			
		||||
                self._config_path = closest_config
 | 
			
		||||
                self._config = self.get_config()
 | 
			
		||||
 | 
			
		||||
        if self._config is None:
 | 
			
		||||
            self._config = self.get_config()
 | 
			
		||||
            if self._config is None:
 | 
			
		||||
                self._config = self.get_config()
 | 
			
		||||
        except Exception:
 | 
			
		||||
            return {}
 | 
			
		||||
 | 
			
		||||
        return self._config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -22,7 +22,7 @@ numpy>=1.23.1
 | 
			
		||||
pypdf
 | 
			
		||||
pandas
 | 
			
		||||
pycryptodome
 | 
			
		||||
markdown
 | 
			
		||||
markdown==3.7
 | 
			
		||||
bs4
 | 
			
		||||
protobuf==3.20.1
 | 
			
		||||
neo4j
 | 
			
		||||
 | 
			
		||||
							
								
								
									
64  tests/unit/common/kag_utils.py  Normal file
@ -0,0 +1,64 @@
 | 
			
		||||
from kag.common.utils import extract_tag_content
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_extra_tag():
 | 
			
		||||
    test_cases = [
 | 
			
		||||
        {
 | 
			
		||||
            "input": "<tag1>abced</tag1>some word<tag2>other tags</tag2>",
 | 
			
		||||
            "expected": [("tag1", "abced"), ("", "some word"), ("tag2", "other tags")],
 | 
			
		||||
            "description": "Basic closed tags mixed with untagged text",
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            "input": "<p>Hello <b>world</b> this is <i>test</i>",
 | 
			
		||||
            "expected": [
 | 
			
		||||
                ("p", "Hello "),
 | 
			
		||||
                ("b", "world"),
 | 
			
		||||
                ("", " this is "),
 | 
			
		||||
                ("i", "test"),
 | 
			
		||||
            ],
 | 
			
		||||
            "description": "Mixed closed and unclosed tags",
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            "input": "plain text without any tags",
 | 
			
		||||
            "expected": [("", "plain text without any tags")],
 | 
			
		||||
            "description": "Plain text without tags",
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            "input": "<div>\n    Line 1\n    <span>Line 2</span>\n    Line 3\n</div>",
 | 
			
		||||
            "expected": [
 | 
			
		||||
                ("div", "\n    Line 1\n    <span>Line 2</span>\n    Line 3\n")
 | 
			
		||||
            ],
 | 
			
		||||
            "description": "Multi-line content and whitespace handling",
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            "input": "<a>A</a><b>B</b><c>C</c>",
 | 
			
		||||
            "expected": [("a", "A"), ("b", "B"), ("c", "C")],
 | 
			
		||||
            "description": "Multiple consecutive closed tags",
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            "input": "<title>My Document</title><content>This is the content",
 | 
			
		||||
            "expected": [("title", "My Document"), ("content", "This is the content")],
 | 
			
		||||
            "description": "Unclosed tag (ends at EOF)",
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            "input": "<log>Error: &*^%$#@!;</log><note>End of log</note>",
 | 
			
		||||
            "expected": [("log", "Error: &*^%$#@!;"), ("note", "End of log")],
 | 
			
		||||
            "description": "Content with special characters",
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            "input": "",
 | 
			
		||||
            "expected": [],
 | 
			
		||||
            "description": "Empty string input",
 | 
			
		||||
        },
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    for i, test in enumerate(test_cases):
 | 
			
		||||
        result = extract_tag_content(test["input"])
 | 
			
		||||
        assert (
 | 
			
		||||
            result == test["expected"]
 | 
			
		||||
        ), f"Test {i+1} failed: {test['description']}\nGot: {result}\nExpected: {test['expected']}"
 | 
			
		||||
        print(f"Test {i+1} passed: {test['description']}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    run_extra_tag()
 | 
			
		||||
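The test cases above pin down the expected behaviour of `extract_tag_content` without showing its implementation. The sketch below is one hypothetical way to satisfy those cases (`extract_tag_content_sketch` is an illustrative name, not the function in `kag.common.utils`): a closed tag yields everything up to its matching close, an unclosed tag yields text up to the next opening tag or the end of input, and text outside any tag is reported with an empty tag name.

```python
import re

# Assumption: tag names consist of word characters and carry no attributes.
OPEN_TAG = re.compile(r"<(\w+)>")


def extract_tag_content_sketch(text):
    result = []
    pos = 0
    while pos < len(text):
        m = OPEN_TAG.search(text, pos)
        if m is None:
            # No more tags: the remainder is untagged text.
            result.append(("", text[pos:]))
            break
        if m.start() > pos:
            # Untagged text between the previous segment and this tag.
            result.append(("", text[pos:m.start()]))
        tag = m.group(1)
        close = f"</{tag}>"
        end = text.find(close, m.end())
        if end != -1:
            # Closed tag: keep the raw inner text, nested tags included.
            result.append((tag, text[m.end():end]))
            pos = end + len(close)
        else:
            # Unclosed tag: stop at the next opening tag or at end of input.
            nxt = OPEN_TAG.search(text, m.end())
            stop = nxt.start() if nxt else len(text)
            result.append((tag, text[m.end():stop]))
            pos = stop
    return result


print(extract_tag_content_sketch("<title>My Document</title><content>This is the content"))
```

Treating an unclosed tag as running until the next opening tag is what lets a partially streamed response be parsed before its closing tag has arrived.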