diff --git a/KAG_VERSION b/KAG_VERSION index a3df0a69..8adc70fd 100644 --- a/KAG_VERSION +++ b/KAG_VERSION @@ -1 +1 @@ -0.8.0 +0.8.0 \ No newline at end of file diff --git a/kag/common/utils.py b/kag/common/utils.py index 7614eea7..40757bf9 100644 --- a/kag/common/utils.py +++ b/kag/common/utils.py @@ -463,9 +463,17 @@ def resolve_instance( def extract_tag_content(text): - # 匹配之间的内容,支持任意标签名 - matches = re.findall(r"<([^>]+)>(.*?)", text, flags=re.DOTALL) - return [(tag, content.strip()) for tag, content in matches] + pattern = r"<(\w+)\b[^>]*>(.*?)|<(\w+)\b[^>]*>([^<]*)|([^<]+)" + results = [] + for match in re.finditer(pattern, text, re.DOTALL): + tag1, content1, tag2, content2, raw_text = match.groups() + if tag1: + results.append((tag1, content1)) # 保留原始内容(含空格) + elif tag2: + results.append((tag2, content2)) # 保留原始内容(含空格) + elif raw_text: + results.append(("", raw_text)) # 保留原始空格 + return results def extract_specific_tag_content(text, tag): diff --git a/kag/solver/executor/math/py_based_math_executor.py b/kag/solver/executor/math/py_based_math_executor.py index 544ff674..6c172bd0 100644 --- a/kag/solver/executor/math/py_based_math_executor.py +++ b/kag/solver/executor/math/py_based_math_executor.py @@ -131,9 +131,11 @@ class PyBasedMathExecutor(ExecutorABC): ) parent_results = format_task_dep_context(task.parents) - parent_results = "\n".join(parent_results) + coder_content = context.kwargs.get("planner_thought", "") + "\n\n".join( + parent_results + ) - parent_results += "\n\n" + contents + coder_content += "\n\n" + contents tries = self.tries error = None @@ -141,7 +143,7 @@ class PyBasedMathExecutor(ExecutorABC): tries -= 1 rst, error, code = self.run_once( math_query, - parent_results, + coder_content, error, segment_name=tag_id, tag_name=f"{task_query}_code_generator", diff --git a/kag/solver/executor/retriever/kag_hybrid_retrieval_executor.py b/kag/solver/executor/retriever/kag_hybrid_retrieval_executor.py index 455f1eeb..721f8323 100644 --- a/kag/solver/executor/retriever/kag_hybrid_retrieval_executor.py +++ b/kag/solver/executor/retriever/kag_hybrid_retrieval_executor.py @@ -42,6 +42,15 @@ from kag.solver.utils import init_prompt_with_fallback logger = logging.getLogger() +def _wrapped_invoke(retriever, task, context, segment_name, kwargs): + start_time = time.time() + output = retriever.invoke( + task, context=context, segment_name=segment_name, **kwargs + ) + elapsed_time = time.time() - start_time + return output, elapsed_time + + @ExecutorABC.register("kag_hybrid_retrieval_executor") class KAGHybridRetrievalExecutor(ExecutorABC): def __init__( @@ -76,6 +85,7 @@ class KAGHybridRetrievalExecutor(ExecutorABC): self.context_select_prompt = context_select_prompt or PromptABC.from_config( {"type": "context_select_prompt"} ) + self.with_llm_select = kwargs.get("with_llm_select", True) @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1)) def context_select_call(self, variables): @@ -152,22 +162,30 @@ class KAGHybridRetrievalExecutor(ExecutorABC): "FINISH", component_name=retriever.name, ) - + # Record start time before submitting the task + start_time = time.time() # Prepare function and submit to thread pool func = partial( - retriever.invoke, + _wrapped_invoke, + retriever, task, - context=context, - segment_name=tag_id, - **kwargs, + context, + tag_id, + kwargs.copy(), ) future = executor.submit(func) + # Save future, retriever, and start_time together futures.append((future, retriever)) # Collect results from each future for future, retriever in futures: try: 
- output = future.result() # Wait for result + output, elapsed_time = future.result() # Wait for result + + # Log the elapsed time for this retriever + logger.info( + f"Retriever {retriever.name} executed in {elapsed_time:.2f} seconds" + ) outputs.append(output) # Log data report after successful execution @@ -241,13 +259,18 @@ class KAGHybridRetrievalExecutor(ExecutorABC): selected_rel = list(set(selected_rel)) formatted_docs = [str(rel) for rel in selected_rel] if retrieved_data.chunks: - try: - selected_chunks = self.context_select(task_query, retrieved_data.chunks) - except Exception as e: - logger.warning( - f"select context failed {e}, we use default top 10 to summary", - exc_info=True, - ) + if self.with_llm_select: + try: + selected_chunks = self.context_select( + task_query, retrieved_data.chunks + ) + except Exception as e: + logger.warning( + f"select context failed {e}, we use default top 10 to summary", + exc_info=True, + ) + selected_chunks = retrieved_data.chunks[:10] + else: selected_chunks = retrieved_data.chunks[:10] for doc in selected_chunks: formatted_docs.append(f"{doc.content}") @@ -280,68 +303,81 @@ class KAGHybridRetrievalExecutor(ExecutorABC): task_query = task.arguments["query"] tag_id = f"{task_query}_begin_task" - self.report_content(reporter, "thinker", tag_id, "", "FINISH", step=task.name) + self.report_content(reporter, "thinker", tag_id, "", "INIT", step=task.name) try: - retrieved_data = self.do_main(task_query, tag_id, task, context, **kwargs) - except Exception as e: - logger.warning(f"kag hybrid retrieval failed! {e}", exc_info=True) - retrieved_data = RetrieverOutput( - retriever_method=self.schema().get("name", ""), err_msg=str(e) - ) - - self.report_content( - reporter, - "reference", - f"{task_query}_kag_retriever_result", - retrieved_data, - "FINISH", - ) - - retrieved_data.task = task - logical_node = task.arguments.get("logic_form_node", None) - if ( - logical_node - and isinstance(logical_node, GetSPONode) - and retrieved_data.summary - ): - if isinstance(retrieved_data.summary, str): - target_answer = retrieved_data.summary.split("Answer:")[-1].strip() - s_entities = context.variables_graph.get_entity_by_alias( - logical_node.s.alias_name + try: + retrieved_data = self.do_main( + task_query, tag_id, task, context, **kwargs ) - if ( - not s_entities - and not logical_node.s.get_mention_name() - and isinstance(logical_node.s, SPOEntity) - ): - logical_node.s.entity_name = target_answer - context.kwargs[logical_node.s.alias_name] = logical_node.s - o_entities = context.variables_graph.get_entity_by_alias( - logical_node.o.alias_name + except Exception as e: + logger.warning(f"kag hybrid retrieval failed! 
{e}", exc_info=True) + retrieved_data = RetrieverOutput( + retriever_method=self.schema().get("name", ""), err_msg=str(e) ) - if ( - not o_entities - and not logical_node.o.get_mention_name() - and isinstance(logical_node.o, SPOEntity) - ): - logical_node.o.entity_name = target_answer - context.kwargs[logical_node.o.alias_name] = logical_node.o - context.variables_graph.add_answered_alias( - logical_node.s.alias_name.alias_name, retrieved_data.summary - ) - context.variables_graph.add_answered_alias( - logical_node.p.alias_name.alias_name, retrieved_data.summary - ) - context.variables_graph.add_answered_alias( - logical_node.o.alias_name.alias_name, retrieved_data.summary + self.report_content( + reporter, + "reference", + f"{task_query}_kag_retriever_result", + retrieved_data, + "FINISH", ) - task.update_result(retrieved_data) - logger.debug( - f"kag hybrid retrieval {task_query} cost={time.time() - start_time}" - ) - return retrieved_data + retrieved_data.task = task + logical_node = task.arguments.get("logic_form_node", None) + if ( + logical_node + and isinstance(logical_node, GetSPONode) + and retrieved_data.summary + ): + if isinstance(retrieved_data.summary, str): + target_answer = retrieved_data.summary.split("Answer:")[-1].strip() + s_entities = context.variables_graph.get_entity_by_alias( + logical_node.s.alias_name + ) + if ( + not s_entities + and not logical_node.s.get_mention_name() + and isinstance(logical_node.s, SPOEntity) + ): + logical_node.s.entity_name = target_answer + context.kwargs[logical_node.s.alias_name] = logical_node.s + o_entities = context.variables_graph.get_entity_by_alias( + logical_node.o.alias_name + ) + if ( + not o_entities + and not logical_node.o.get_mention_name() + and isinstance(logical_node.o, SPOEntity) + ): + logical_node.o.entity_name = target_answer + context.kwargs[logical_node.o.alias_name] = logical_node.o + + context.variables_graph.add_answered_alias( + logical_node.s.alias_name.alias_name, retrieved_data.summary + ) + context.variables_graph.add_answered_alias( + logical_node.p.alias_name.alias_name, retrieved_data.summary + ) + context.variables_graph.add_answered_alias( + logical_node.o.alias_name.alias_name, retrieved_data.summary + ) + + task.update_result(retrieved_data) + logger.debug( + f"kag hybrid retrieval {task_query} cost={time.time() - start_time}" + ) + return retrieved_data + finally: + self.report_content( + reporter, + "thinker", + tag_id, + "", + "FINISH", + step=task.name, + overwrite=False, + ) def schema(self) -> dict: """Function schema definition for OpenAI Function Calling @@ -403,7 +439,7 @@ class KAGHybridRetrievalExecutor(ExecutorABC): node_type=chunk.properties.get("__labels__"), ) entity_prop = dict(chunk.properties) if chunk.properties else {} - entity_prop["content"] = chunk.content + entity_prop["content"] = f"{chunk.content[:10]}..." 
entity_prop["score"] = chunk.score entity.prop = Prop.from_dict(entity_prop, "Chunk", None) chunk_graph.append(entity) diff --git a/kag/solver/main_solver.py b/kag/solver/main_solver.py index f2fb4b16..32120a3a 100644 --- a/kag/solver/main_solver.py +++ b/kag/solver/main_solver.py @@ -140,8 +140,6 @@ def get_pipeline_conf(use_pipeline_name, config): raise RuntimeError("mcpServers not found in config.") default_solver_pipeline["executors"] = mcp_executors - # update KAG_CONFIG - KAG_CONFIG.update_conf(default_pipeline_conf) return default_solver_pipeline @@ -167,8 +165,11 @@ async def do_qa_pipeline( f"Knowledge base with id {kb_project_id} not found in qa_config['kb']" ) continue - - for index_name in matched_kb.get("index_list", []): + index_list = matched_kb.get("index_list", []) + if use_pipeline in ["default_pipeline"]: + # we only use chunk index + index_list = ["chunk_index"] + for index_name in index_list: index_manager = KAGIndexManager.from_config( { "type": index_name, @@ -339,7 +340,7 @@ class SolverMain: def invoke( self, project_id: int, - task_id: int, + task_id, query: str, session_id: str = "0", is_report=True, diff --git a/kag/solver/pipelineconf/naive_rag.yaml b/kag/solver/pipelineconf/naive_rag.yaml index 2cc249e4..84d4b8a6 100644 --- a/kag/solver/pipelineconf/naive_rag.yaml +++ b/kag/solver/pipelineconf/naive_rag.yaml @@ -3,20 +3,17 @@ pipeline_name: default_pipeline #------------kag-solver configuration start----------------# - -chunk_retrieved_executor: &chunk_retrieved_executor_conf - type: chunk_retrieved_executor - top_k: 10 - retriever: - type: vector_chunk_retriever - score_threshold: 0.65 - vectorize_model: "{vectorize_model}" - +kag_retriever_executor: &kag_retriever_executor_conf + type: kag_hybrid_retrieval_executor + retrievers: "{retrievers}" + merger: + type: kag_merger + enable_summary: false solver_pipeline: type: naive_rag_pipeline executors: - - *chunk_retrieved_executor_conf + - *kag_retriever_executor_conf generator: type: llm_index_generator llm_client: "{chat_llm}" diff --git a/kag/solver/planner/kag_model_planner.py b/kag/solver/planner/kag_model_planner.py index 62bcd6c2..008ac3c1 100644 --- a/kag/solver/planner/kag_model_planner.py +++ b/kag/solver/planner/kag_model_planner.py @@ -186,6 +186,7 @@ class KAGModelPlanner(PlannerABC): .replace("", "") .strip() ) + context.kwargs["planner_thought"] = logic_form_response sub_queries, logic_forms = parse_logic_form_with_str(logic_form_str) logic_forms = self.logic_node_parser.parse_logic_form_set( diff --git a/kag/solver/prompt/expression_builder.py b/kag/solver/prompt/expression_builder.py index 4002bb48..36c25dbd 100644 --- a/kag/solver/prompt/expression_builder.py +++ b/kag/solver/prompt/expression_builder.py @@ -11,42 +11,49 @@ logger = logging.getLogger(__name__) class ExpressionBuildr(PromptABC): template_zh = ( f"今天是{get_now(language='zh')}。" - + """\n# instruction + + """ +# instruction 根据给出的问题和数据,编写python代码,输出问题结果。 -为了便于理解,输出从context中提取的数据,输出中间计算过程和结果。 -注意严格根据输入内容进行编写代码,不允许进行假设 -例如伤残等级如果context中未提及,则认为没有被认定为残疾 -如果无法回答问题,直接返回:I don't know. +从context中提取的数据必须显式赋值,所有计算步骤必须用代码实现,不得隐含推断。 +必须输出中间计算过程和结果,格式为print语句。 +如果context未提供必要数据或无法计算,直接打印"I don't know." # output format -直接输出python代码,python版本为3.10,不要包含任何其他信息 +严格输出以下结构的python代码(版本3.10): +1. 数据提取部分:代码中涉及输入的数值需要从context及question中提取,不允许进行假设 +2. 计算过程:分步实现所有数学运算,每个步骤对应独立变量 +3. 输出:每个中间变量和最终结果必须用print语句输出 # examples ## 例子1 ### input #### question -47000元按照万分之1.5一共612天,计算利息,一共多少钱? +4百万元按照日利率万分之1.5,一共612天,计算利息,一共多少钱? 
+#### context +日利率万分之1.5 ### output ```python -# 初始本金 -principal = 47000 +# 初始本金(单位:百万) +principal = 4 # 单位:百万 -# 利率(万分之1.5) -rate = 1.5 / 10000 +# 日利率计算(万分之1.5) +daily_rate = 1.5 / 10000 -# 天数 +# 计算周期 days = 612 -# 计算年利率 -annual_rate = rate * 365 +# 单日利息计算 +daily_interest = principal * daily_rate -# 计算利息 -interest = principal * (annual_rate / 365) * days +# 累计利息计算 +total_interest = daily_interest * days -# 输出总金额(本金+利息) -total_amount = principal + interest +# 总金额计算 +total_amount = principal + total_interest -print(f"总金额:{total_amount:.2f}元") +print(f"单日利息:{daily_interest:.2f}百万") +print(f"累计利息:{total_interest:.2f}百万") +print(f"总金额:{total_amount:.2f}百万") ``` ## 例子2 @@ -70,13 +77,26 @@ revenue_2020 = revenue_2019 * (1 + growth_rate) print(f"2020年的预计收入为: {revenue_2020:.2f}万") ``` +## 例子3 +### input +#### question +47000元按照612天计算利息,本息一共多少钱? +#### content + +### output +```python +# 未给出利率,无法计算 +print("未给出利率,无法计算") +``` # input ## question $question ## context $context ## error - $error""" +$error +## output +""" ) template_en = ( f"Today is {get_now(language='en')}。\n" diff --git a/kag/solver/reporter/open_spg_kag_model_reporter.py b/kag/solver/reporter/open_spg_kag_model_reporter.py index e5ceae7d..d569a16b 100644 --- a/kag/solver/reporter/open_spg_kag_model_reporter.py +++ b/kag/solver/reporter/open_spg_kag_model_reporter.py @@ -1,5 +1,6 @@ import logging import re +import time from kag.common.conf import KAG_PROJECT_CONF from kag.common.parser.logic_node_parser import extract_steps_and_actions @@ -72,14 +73,16 @@ def process_tag_template(text): } clean_text = "" for tag_info in all_tags: + content = tag_info[1] if tag_info[0] in xml_tag_template: - content = tag_info[1] if "search" == tag_info[0]: content = process_planning(content) clean_text += xml_tag_template[tag_info[0]][ KAG_PROJECT_CONF.language ].format_map(SafeDict({"content": content})) - return remove_xml_tags(clean_text) + else: + clean_text += content + text = remove_xml_tags(clean_text) return text diff --git a/kag/solver/reporter/open_spg_reporter.py b/kag/solver/reporter/open_spg_reporter.py index ef2ae01a..8d0d9c95 100644 --- a/kag/solver/reporter/open_spg_reporter.py +++ b/kag/solver/reporter/open_spg_reporter.py @@ -159,12 +159,12 @@ def render_jinja2_template(template_str, context): """ try: template = Template(template_str, undefined=SilentUndefined) - return template.render(**context).strip() + return template.render(**context) except Exception as e: logging.error( f"Jinja2 rendering failed: {e}, Original template: {template_str}" ) - return template_str.strip() # Fallback to raw template string on failure + return template_str # Fallback to raw template string on failure @ReporterABC.register("open_spg_reporter") @@ -264,12 +264,12 @@ Rerank the documents and take the top {{ chunk_num }}. } self.tag_mapping = { "Graph Show": { - "en": "{content}", - "zh": "{content}", + "en": "{{ content }}", + "zh": "{{ content }}", }, "Rewrite query": { - "en": "Rethinking question using LLM: {content}", - "zh": "根据依赖问题重写子问题: {content}", + "en": "Rethinking question using LLM: {{ content }}", + "zh": "根据依赖问题重写子问题: {{ content }}", }, "language_setting": { "en": "", @@ -277,125 +277,153 @@ Rerank the documents and take the top {{ chunk_num }}. 
}, "Iterative planning": { "en": """ - + -{content} +{{ content }} -""", +{% if status == 'success' %} + +{% endif %}""", "zh": """ - + -{content} +{{ content }} -""", +{% if status == 'success' %} + +{% endif %}""", }, "Static planning": { "en": """ - + -{content} +{{ content }} -""", +{% if status == 'success' %} + +{% endif %}""", "zh": """ - + -{content} +{{ content }} -""", +{% if status == 'success' %} + +{% endif %}""", }, "begin_sub_kag_retriever": { - "en": "Starting {component_name}: {content} {desc}", - "zh": "执行{component_name}: {content} {desc}", + "en": "Starting {{component_name}}: {{content}} {{desc}}", + "zh": "执行{{component_name}}: {{content}} {{desc}}", }, "end_sub_kag_retriever": { - "en": " {content}", - "zh": " {content}", + "en": " {{ content }}", + "zh": " {{ content }}", }, "rc_retriever_rewrite": { "en": """ - + -Rewritten question:\n{content} +Rewritten question: +{{ content }} -""", +{% if status == 'success' %} + +{% endif %}""", "zh": """ - + -重写问题为:\n\n{content} +重写问题为: +{{ content }} -""", +{% if status == 'success' %} + +{% endif %}""", }, "rc_retriever_summary": { - "en": "Summarizing retrieved documents,{content}", - "zh": "对文档进行总结,{content}", + "en": "Summarizing retrieved documents,{{ content }}", + "zh": "对文档进行总结,{{ content }}", }, "kg_retriever_summary": { - "en": "Summarizing retrieved graph,{content}", - "zh": "对召回的知识进行总结,{content}", + "en": "Summarizing retrieved graph,{{ content }}", + "zh": "对召回的知识进行总结,{{ content }}", }, "retriever_summary": { - "en": "Summarizing retrieved documents,{content}", - "zh": "对文档进行总结,{content}", + "en": "Summarizing retrieved documents,{{ content }}", + "zh": "对文档进行总结,{{ content }}", }, "begin_summary": { - "en": "Summarizing retrieved information, {content}", - "zh": "对检索的信息进行总结, {content}", + "en": "Summarizing retrieved information, {{ content }}", + "zh": "对检索的信息进行总结, {{ content }}", }, "begin_task": { "en": """ - + -{content} +{{ content }} -""", +{% if status == 'success' %} + +{% endif %}""", "zh": """ - + -{content} +{{ content }} -""", +{% if status == 'success' %} + +{% endif %}""", }, "logic_node": { "en": """Translate query to logic form expression ```json -{content} +{{ content }} ```""", "zh": """将query转换成逻辑形式表达 ```json -{content} +{{ content }} ```""", }, "kag_retriever_result": { - "en": "Retrieved documents\n\n{content}", - "zh": "检索到的文档\n\n{content}", + "en": """Retrieved documents +{{ content }}""", + "zh": """检索到的文档 +{{ content }}""", }, "failed_kag_retriever": { "en": """KAG retriever failed ```json -{content} +{{ content }} ``` """, "zh": """KAG检索失败 ```json -{content} +{{ content }} ``` """, }, "end_math_executor": { - "en": "Math executor completed\n\n{content}", - "zh": "计算结束\n\n{content}", + "en": """Math executor completed +{{ content }}""", + "zh": """计算结束 +{{ content }}""", }, "code_generator": { - "en": "Generating code\n \n{content}\n", - "zh": "正在生成代码\n \n{content}\n", + "en": """Generating code +{{ content }} + +""", + "zh": """正在生成代码 +{{ content }} + +""", }, } task_id = kwargs.get(KAGConstants.KAG_QA_TASK_CONFIG_KEY, None) @@ -425,7 +453,11 @@ Rewritten question:\n{content} if tpl: format_params = {"content": datas} format_params.update(content_params) - datas = tpl.format_map(SafeDict(format_params)) + if "{" in tpl or "%}" in tpl: + rendered = render_jinja2_template(tpl, format_params) + else: + rendered = tpl.format_map(SafeDict(format_params)) + datas = rendered elif str(datas).strip() != "": output = str(datas).strip() if output != "": @@ -516,7 +548,7 @@ Rewritten 
question:\n{content} if self.last_report.to_dict() == request.to_dict(): return logger.info( - f"do_report: {content.answer} think={content.think} status={status_enum} ret={ret}" + f"do_report: think={content.think} {content.answer} status={status_enum} ret={ret}" ) self.last_report = request diff --git a/kag/solver/server/__init__.py b/kag/solver/server/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/kag/solver/server/asyn_task_manager.py b/kag/solver/server/asyn_task_manager.py new file mode 100644 index 00000000..069c4e89 --- /dev/null +++ b/kag/solver/server/asyn_task_manager.py @@ -0,0 +1,168 @@ +import concurrent.futures +import queue +import threading +import time +import uuid +import logging +from cachetools import TTLCache + +logger = logging.getLogger() + + +class AsyncTaskManager: + def __init__(self, max_workers=10, ttl=3600): + """ + Initialize async task manager + + Args: + max_workers (int): Maximum number of worker threads + ttl (int): Time-to-live for task results in seconds + """ + self.max_workers = max_workers + self.task_queue = queue.Queue() + self.result_cache = TTLCache(maxsize=1000, ttl=ttl) + self.result_cache_lock = threading.Lock() # Protect cache from race conditions + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) + self.workers = [ + threading.Thread(target=self.worker, daemon=True) + for _ in range(max_workers) + ] + for w in self.workers: + w.start() + + def worker(self): + """Worker thread main loop that processes tasks""" + while True: + try: + # Get next task from queue with timeout to allow shutdown detection + task = self.task_queue.get() + task_id, func, args, kwargs = task + logger.info(f"Processing task {task_id}") + # finish flag + if task_id is None: + self.task_queue.task_done() + break + + # Update cache with running status + with self.result_cache_lock: + self.result_cache[task_id] = { + "task_id": task_id, + "status": "running", + "result": None, + } + + # Execute task + future = self.executor.submit(func, *args, **kwargs) + result = future.result() + status = "completed" + + except queue.Empty: + # Handle queue empty timeout (normal operation) + continue + + except Exception as e: + # Handle task execution errors + result = str(e) + status = "failed" + logger.error(f"Task {task_id} failed with error: {e}", exc_info=True) + + # Store final result in cache + try: + with self.result_cache_lock: + self.result_cache[task_id] = { + "task_id": task_id, + "status": status, + "result": result, + } + logger.info(f"Task {task_id} completed with status: {status}") + finally: + # Always mark task as done + self.task_queue.task_done() + + def submit_task(self, func, *args, **kwargs): + """ + Submit a new task to the queue + + Args: + func: Callable function to execute + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + str: Unique task ID + """ + task_id = str(uuid.uuid4()) + self.task_queue.put((task_id, func, args, kwargs)) + return task_id + + def get_task_result(self, task_id): + """ + Get result for a specific task + + Args: + task_id (str): Unique task identifier + + Returns: + dict: Task result information or expired status + """ + with self.result_cache_lock: + return self.result_cache.get( + task_id, + { + "task_id": task_id, + "status": "failed", + "result": "Result not found or expired", + }, + ) + + def shutdown(self): + """Gracefully shutdown all worker threads and executors""" + # Send shutdown signals + for _ in 
range(self.max_workers): + self.task_queue.put((None, None, (), {})) + + # Wait for queue to empty and workers to terminate + self.task_queue.join() + + # Shutdown executors + self.executor.shutdown(wait=True) + for worker in self.workers: + worker.join(timeout=5) + + +# Global async task manager instance +asyn_task = AsyncTaskManager() + +if __name__ == "__main__": + # Configure logging + logging.basicConfig(level=logging.INFO) + + # Create task manager instance + task_manager = AsyncTaskManager(max_workers=5, ttl=600) + + # Example task function + def example_task(x, y): + time.sleep(1) # Simulate work + return x + + # Submit test tasks + task_ids = [task_manager.submit_task(example_task, i, i + 1) for i in range(6)] + + # Monitor task progress + try: + while True: + time.sleep(1) + if all( + "completed" in task_manager.get_task_result(tid)["status"] + for tid in task_ids + ): + break + except KeyboardInterrupt: + logger.info("Shutting down due to user interrupt") + + # Print results + for task_id in task_ids: + print(f"Task {task_id} result: {task_manager.get_task_result(task_id)}") + + # Clean up resources + task_manager.shutdown() diff --git a/kag/solver/server/example/__init__.py b/kag/solver/server/example/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/kag/solver/server/main_server.py b/kag/solver/server/main_server.py new file mode 100644 index 00000000..755e3590 --- /dev/null +++ b/kag/solver/server/main_server.py @@ -0,0 +1,63 @@ +from fastapi import FastAPI +import uvicorn + +from kag.solver.main_solver import SolverMain +from kag.solver.server.asyn_task_manager import AsyncTaskManager +from kag.solver.server.model.task_req import FeatureRequest, TaskReq + + +def run_main_solver(task: TaskReq): + return SolverMain().invoke( + project_id=task.project_id, + task_id=task.req_id, + query=task.req.query, + is_report=task.req.report, + host_addr=task.req.host_addr, + app_id=task.app_id, + params=task.config, + ) + + +class KAGSolverServer: + def __init__(self, service_name: str): + """ + Initialize a FastAPI service instance + + Args: + service_name (str): Service name, determines which routing logic to load + """ + self.service_name = service_name + self.app = FastAPI(title=f"{service_name} API") + + # Bind routes according to service name + self._setup_routes() + self.async_manager = AsyncTaskManager() + + def sync_task(self, task: TaskReq): + if task.cmd == "submit": + + return { + "success": True, + "status": "init", + "result": self.async_manager.submit_task(run_main_solver, task), + } + elif task.cmd == "query": + return self.async_manager.get_task_result(task_id=task.req_id) + else: + return { + "success": False, + "status": "failed", + "result": f"invalid input cmd {task.cmd}", + } + + def _setup_routes(self): + """Dynamically bind routes according to service name""" + + @self.app.post("/process") + def process(req: FeatureRequest): + return self.sync_task(task=req.features.task_req) + + def run(self, host="0.0.0.0", port=8000): + """Start the service""" + print(f"Starting {self.service_name} service on {host}:{port}") + uvicorn.run(self.app, host=host, port=port) diff --git a/kag/solver/server/model/__init__.py b/kag/solver/server/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/kag/solver/server/model/task_req.py b/kag/solver/server/model/task_req.py new file mode 100644 index 00000000..0a48bd3b --- /dev/null +++ b/kag/solver/server/model/task_req.py @@ -0,0 +1,122 @@ +import json +from typing import Optional + +from 
pydantic import BaseModel, model_validator, field_serializer + + +class ReqBody(BaseModel): + """Request body model containing query parameters""" + + query: str = "" + report: bool = True + host_addr: str = "" + + +class TaskReq(BaseModel): + """Task request model with validation logic""" + + app_id: int = "" + project_id: int = 0 + req_id: str = "" + cmd: str = "" + mode: str = "" + req: str = None + config: str = "{}" + + @model_validator(mode="after") + def parse_req_to_req_body(self): + """Parse req string to ReqBody object and process config field""" + try: + import json + + if isinstance(self.req, str): + req_body_dict = json.loads(self.req) + self.req = ReqBody(**req_body_dict) + if isinstance(self.config, str) and self.config: + config_dict = json.loads(self.config) + self.config = config_dict + except Exception as e: + raise ValueError(f"Failed to parse 'req' field to ReqBody: {e}") + return self + + @field_serializer("req") + def serialize_req(self, value: object) -> object: + """Serialize ReqBody back to JSON string""" + if isinstance(value, ReqBody): + return value.model_dump_json() + return value # Return as-is if already a string + + +# Request model with TaskReq parsing capability +class Request(BaseModel): + """Container model for task request data""" + + in_string: str + task_req: Optional[TaskReq] = None + + @model_validator(mode="after") + def parse_in_string_to_task_req(self): + """Convert in_string JSON string to TaskReq object""" + try: + import json + + task_req_dict = json.loads(self.in_string) + self.task_req = TaskReq(**task_req_dict) + except Exception as e: + raise ValueError(f"Invalid TaskReq JSON string: {e}") + return self + + +class FeatureRequest(BaseModel): + """Top-level request wrapper with features container""" + + features: Request + + +if __name__ == "__main__": + + def feature_request_parsing(): + """Demonstrate nested model parsing workflow""" + # Build innermost ReqBody JSON string + req_body = ReqBody( + query="阿里巴巴财报中,2024年-截至9月30日止六个月的收入是多少?其中云智能集团收入是多少?占比是多少", + report=True, + host_addr="https://spg.alipay.com", + ) + req_body_json = json.dumps(req_body.model_dump()) + + # Build TaskReq dictionary and serialize to string + task_req = TaskReq( + req_id="9400110", + cmd="submit", + mode="async", + req=req_body_json, + app_id="app_id", + project_id=4200050, + config={"timeout": 10}, + ) + task_req_json = json.dumps(task_req.model_dump()) + + # Construct final FeatureRequest JSON string + input_data = {"features": {"in_string": task_req_json}} + + # Deserialize to FeatureRequest model + feature_request = FeatureRequest(**input_data) + + # Validate in_string parsed to TaskReq + assert isinstance(feature_request.features.task_req, TaskReq) + assert feature_request.features.task_req.req_id == "abc123" + assert feature_request.features.task_req.cmd == "run" + assert feature_request.features.task_req.mode == "sync" + assert feature_request.features.task_req.config == {"timeout": 10} + + # Validate TaskReq.req parsed to ReqBody + req_body_parsed = feature_request.features.task_req.req + assert isinstance(req_body_parsed, ReqBody) + assert req_body_parsed.query == "What is AI?" 
+ assert req_body_parsed.report is True + assert req_body_parsed.host_addr == "localhost" + + print("✅ All assertions passed!") + + feature_request_parsing() diff --git a/kag/solver/server/requirement.txt b/kag/solver/server/requirement.txt new file mode 100644 index 00000000..170703df --- /dev/null +++ b/kag/solver/server/requirement.txt @@ -0,0 +1 @@ +fastapi \ No newline at end of file diff --git a/knext/common/env.py b/knext/common/env.py index 6a9994c7..595c8d42 100644 --- a/knext/common/env.py +++ b/knext/common/env.py @@ -48,13 +48,16 @@ class Environment: @property def config(self): - closest_config = self._closest_config() - if not hasattr(self, "_config_path") or self._config_path != closest_config: - self._config_path = closest_config - self._config = self.get_config() + try: + closest_config = self._closest_config() + if not hasattr(self, "_config_path") or self._config_path != closest_config: + self._config_path = closest_config + self._config = self.get_config() - if self._config is None: - self._config = self.get_config() + if self._config is None: + self._config = self.get_config() + except: + return {} return self._config diff --git a/requirements.txt b/requirements.txt index 2085ea8d..44ed0d4b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ numpy>=1.23.1 pypdf pandas pycryptodome -markdown +markdown==3.7 bs4 protobuf==3.20.1 neo4j diff --git a/tests/unit/common/kag_utils.py b/tests/unit/common/kag_utils.py new file mode 100644 index 00000000..5d18d853 --- /dev/null +++ b/tests/unit/common/kag_utils.py @@ -0,0 +1,64 @@ +from kag.common.utils import extract_tag_content + + +def run_extra_tag(): + test_cases = [ + { + "input": "abcedsome wordother tags", + "expected": [("tag1", "abced"), ("", "some word"), ("tag2", "other tags")], + "description": "基本闭合标签与无标签文本混合", + }, + { + "input": "

<p>Hello <b>world</b> this is <i>test",
+            "expected": [
+                ("p", "Hello "),
+                ("b", "world"),
+                ("", " this is "),
+                ("i", "test"),
+            ],
+            "description": "混合闭合与未闭合标签",
+        },
+        {
+            "input": "plain text without any tags",
+            "expected": [("", "plain text without any tags")],
+            "description": "纯文本无标签",
+        },
+        {
+            "input": "<div>\n Line 1\n Line 2\n Line 3\n</div>
", + "expected": [ + ("div", "\n Line 1\n Line 2\n Line 3\n") + ], + "description": "多行内容和空白处理", + }, + { + "input": "ABC", + "expected": [("a", "A"), ("b", "B"), ("c", "C")], + "description": "连续多个闭合标签", + }, + { + "input": "My DocumentThis is the content", + "expected": [("title", "My Document"), ("content", "This is the content")], + "description": "未闭合标签(EOF结尾)", + }, + { + "input": "Error: &*^%$#@!;End of log", + "expected": [("log", "Error: &*^%$#@!;"), ("note", "End of log")], + "description": "含特殊字符的内容", + }, + { + "input": "", + "expected": [], + "description": "空字符串输入", + }, + ] + + for i, test in enumerate(test_cases): + result = extract_tag_content(test["input"]) + assert ( + result == test["expected"] + ), f"Test {i+1} failed: {test['description']}\nGot: {result}\nExpected: {test['expected']}" + print(f"Test {i+1} passed: {test['description']}") + + +if __name__ == "__main__": + run_extra_tag()