diff --git a/KAG_VERSION b/KAG_VERSION
index a3df0a69..8adc70fd 100644
--- a/KAG_VERSION
+++ b/KAG_VERSION
@@ -1 +1 @@
-0.8.0
+0.8.0
\ No newline at end of file
diff --git a/kag/common/utils.py b/kag/common/utils.py
index 7614eea7..40757bf9 100644
--- a/kag/common/utils.py
+++ b/kag/common/utils.py
@@ -463,9 +463,17 @@ def resolve_instance(
def extract_tag_content(text):
- # Match the content between <tag> and </tag>; any tag name is supported
- matches = re.findall(r"<([^>]+)>(.*?)</\1>", text, flags=re.DOTALL)
- return [(tag, content.strip()) for tag, content in matches]
+ pattern = r"<(\w+)\b[^>]*>(.*?)</\1>|<(\w+)\b[^>]*>([^<]*)|([^<]+)"
+ results = []
+ for match in re.finditer(pattern, text, re.DOTALL):
+ tag1, content1, tag2, content2, raw_text = match.groups()
+ if tag1:
+ results.append((tag1, content1)) # keep the original content (including whitespace)
+ elif tag2:
+ results.append((tag2, content2)) # keep the original content (including whitespace)
+ elif raw_text:
+ results.append(("", raw_text)) # keep the original whitespace
+ return results
def extract_specific_tag_content(text, tag):
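A quick usage sketch (not part of the patch) of the behaviour the new pattern aims for, assuming the `</\1>` closing-tag alternative above: plain text and unclosed tags are now returned instead of being dropped.

```python
# Minimal sketch: exercising extract_tag_content on mixed input.
from kag.common.utils import extract_tag_content

text = "<p>Hello <b>world</b> tail"
for tag, content in extract_tag_content(text):
    print(repr(tag), repr(content))
# 'p' 'Hello '   <- unclosed tag, content runs up to the next '<'
# 'b' 'world'    <- closed tag
# '' ' tail'     <- bare text, whitespace preserved
```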
diff --git a/kag/solver/executor/math/py_based_math_executor.py b/kag/solver/executor/math/py_based_math_executor.py
index 544ff674..6c172bd0 100644
--- a/kag/solver/executor/math/py_based_math_executor.py
+++ b/kag/solver/executor/math/py_based_math_executor.py
@@ -131,9 +131,11 @@ class PyBasedMathExecutor(ExecutorABC):
)
parent_results = format_task_dep_context(task.parents)
- parent_results = "\n".join(parent_results)
+ coder_content = context.kwargs.get("planner_thought", "") + "\n\n".join(
+ parent_results
+ )
- parent_results += "\n\n" + contents
+ coder_content += "\n\n" + contents
tries = self.tries
error = None
@@ -141,7 +143,7 @@ class PyBasedMathExecutor(ExecutorABC):
tries -= 1
rst, error, code = self.run_once(
math_query,
- parent_results,
+ coder_content,
error,
segment_name=tag_id,
tag_name=f"{task_query}_code_generator",
diff --git a/kag/solver/executor/retriever/kag_hybrid_retrieval_executor.py b/kag/solver/executor/retriever/kag_hybrid_retrieval_executor.py
index 455f1eeb..721f8323 100644
--- a/kag/solver/executor/retriever/kag_hybrid_retrieval_executor.py
+++ b/kag/solver/executor/retriever/kag_hybrid_retrieval_executor.py
@@ -42,6 +42,15 @@ from kag.solver.utils import init_prompt_with_fallback
logger = logging.getLogger()
+def _wrapped_invoke(retriever, task, context, segment_name, kwargs):
+ start_time = time.time()
+ output = retriever.invoke(
+ task, context=context, segment_name=segment_name, **kwargs
+ )
+ elapsed_time = time.time() - start_time
+ return output, elapsed_time
+
+
@ExecutorABC.register("kag_hybrid_retrieval_executor")
class KAGHybridRetrievalExecutor(ExecutorABC):
def __init__(
@@ -76,6 +85,7 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
self.context_select_prompt = context_select_prompt or PromptABC.from_config(
{"type": "context_select_prompt"}
)
+ self.with_llm_select = kwargs.get("with_llm_select", True)
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1))
def context_select_call(self, variables):
@@ -152,22 +162,30 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
"FINISH",
component_name=retriever.name,
)
-
+ # Record start time before submitting the task
+ start_time = time.time()
# Prepare function and submit to thread pool
func = partial(
- retriever.invoke,
+ _wrapped_invoke,
+ retriever,
task,
- context=context,
- segment_name=tag_id,
- **kwargs,
+ context,
+ tag_id,
+ kwargs.copy(),
)
future = executor.submit(func)
+ # Save each future together with its retriever
futures.append((future, retriever))
# Collect results from each future
for future, retriever in futures:
try:
- output = future.result() # Wait for result
+ output, elapsed_time = future.result() # Wait for result
+
+ # Log the elapsed time for this retriever
+ logger.info(
+ f"Retriever {retriever.name} executed in {elapsed_time:.2f} seconds"
+ )
outputs.append(output)
# Log data report after successful execution
@@ -241,13 +259,18 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
selected_rel = list(set(selected_rel))
formatted_docs = [str(rel) for rel in selected_rel]
if retrieved_data.chunks:
- try:
- selected_chunks = self.context_select(task_query, retrieved_data.chunks)
- except Exception as e:
- logger.warning(
- f"select context failed {e}, we use default top 10 to summary",
- exc_info=True,
- )
+ if self.with_llm_select:
+ try:
+ selected_chunks = self.context_select(
+ task_query, retrieved_data.chunks
+ )
+ except Exception as e:
+ logger.warning(
+ f"select context failed {e}, we use default top 10 to summary",
+ exc_info=True,
+ )
+ selected_chunks = retrieved_data.chunks[:10]
+ else:
selected_chunks = retrieved_data.chunks[:10]
for doc in selected_chunks:
formatted_docs.append(f"{doc.content}")
@@ -280,68 +303,81 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
task_query = task.arguments["query"]
tag_id = f"{task_query}_begin_task"
- self.report_content(reporter, "thinker", tag_id, "", "FINISH", step=task.name)
+ self.report_content(reporter, "thinker", tag_id, "", "INIT", step=task.name)
try:
- retrieved_data = self.do_main(task_query, tag_id, task, context, **kwargs)
- except Exception as e:
- logger.warning(f"kag hybrid retrieval failed! {e}", exc_info=True)
- retrieved_data = RetrieverOutput(
- retriever_method=self.schema().get("name", ""), err_msg=str(e)
- )
-
- self.report_content(
- reporter,
- "reference",
- f"{task_query}_kag_retriever_result",
- retrieved_data,
- "FINISH",
- )
-
- retrieved_data.task = task
- logical_node = task.arguments.get("logic_form_node", None)
- if (
- logical_node
- and isinstance(logical_node, GetSPONode)
- and retrieved_data.summary
- ):
- if isinstance(retrieved_data.summary, str):
- target_answer = retrieved_data.summary.split("Answer:")[-1].strip()
- s_entities = context.variables_graph.get_entity_by_alias(
- logical_node.s.alias_name
+ try:
+ retrieved_data = self.do_main(
+ task_query, tag_id, task, context, **kwargs
)
- if (
- not s_entities
- and not logical_node.s.get_mention_name()
- and isinstance(logical_node.s, SPOEntity)
- ):
- logical_node.s.entity_name = target_answer
- context.kwargs[logical_node.s.alias_name] = logical_node.s
- o_entities = context.variables_graph.get_entity_by_alias(
- logical_node.o.alias_name
+ except Exception as e:
+ logger.warning(f"kag hybrid retrieval failed! {e}", exc_info=True)
+ retrieved_data = RetrieverOutput(
+ retriever_method=self.schema().get("name", ""), err_msg=str(e)
)
- if (
- not o_entities
- and not logical_node.o.get_mention_name()
- and isinstance(logical_node.o, SPOEntity)
- ):
- logical_node.o.entity_name = target_answer
- context.kwargs[logical_node.o.alias_name] = logical_node.o
- context.variables_graph.add_answered_alias(
- logical_node.s.alias_name.alias_name, retrieved_data.summary
- )
- context.variables_graph.add_answered_alias(
- logical_node.p.alias_name.alias_name, retrieved_data.summary
- )
- context.variables_graph.add_answered_alias(
- logical_node.o.alias_name.alias_name, retrieved_data.summary
+ self.report_content(
+ reporter,
+ "reference",
+ f"{task_query}_kag_retriever_result",
+ retrieved_data,
+ "FINISH",
)
- task.update_result(retrieved_data)
- logger.debug(
- f"kag hybrid retrieval {task_query} cost={time.time() - start_time}"
- )
- return retrieved_data
+ retrieved_data.task = task
+ logical_node = task.arguments.get("logic_form_node", None)
+ if (
+ logical_node
+ and isinstance(logical_node, GetSPONode)
+ and retrieved_data.summary
+ ):
+ if isinstance(retrieved_data.summary, str):
+ target_answer = retrieved_data.summary.split("Answer:")[-1].strip()
+ s_entities = context.variables_graph.get_entity_by_alias(
+ logical_node.s.alias_name
+ )
+ if (
+ not s_entities
+ and not logical_node.s.get_mention_name()
+ and isinstance(logical_node.s, SPOEntity)
+ ):
+ logical_node.s.entity_name = target_answer
+ context.kwargs[logical_node.s.alias_name] = logical_node.s
+ o_entities = context.variables_graph.get_entity_by_alias(
+ logical_node.o.alias_name
+ )
+ if (
+ not o_entities
+ and not logical_node.o.get_mention_name()
+ and isinstance(logical_node.o, SPOEntity)
+ ):
+ logical_node.o.entity_name = target_answer
+ context.kwargs[logical_node.o.alias_name] = logical_node.o
+
+ context.variables_graph.add_answered_alias(
+ logical_node.s.alias_name.alias_name, retrieved_data.summary
+ )
+ context.variables_graph.add_answered_alias(
+ logical_node.p.alias_name.alias_name, retrieved_data.summary
+ )
+ context.variables_graph.add_answered_alias(
+ logical_node.o.alias_name.alias_name, retrieved_data.summary
+ )
+
+ task.update_result(retrieved_data)
+ logger.debug(
+ f"kag hybrid retrieval {task_query} cost={time.time() - start_time}"
+ )
+ return retrieved_data
+ finally:
+ self.report_content(
+ reporter,
+ "thinker",
+ tag_id,
+ "",
+ "FINISH",
+ step=task.name,
+ overwrite=False,
+ )
def schema(self) -> dict:
"""Function schema definition for OpenAI Function Calling
@@ -403,7 +439,7 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
node_type=chunk.properties.get("__labels__"),
)
entity_prop = dict(chunk.properties) if chunk.properties else {}
- entity_prop["content"] = chunk.content
+ entity_prop["content"] = f"{chunk.content[:10]}..."
entity_prop["score"] = chunk.score
entity.prop = Prop.from_dict(entity_prop, "Chunk", None)
chunk_graph.append(entity)
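The reason for `_wrapped_invoke` is that once several retrievers run in one thread pool, `future.result()` alone cannot attribute wall-clock time to an individual retriever, so the timing has to happen inside the worker. A minimal sketch of the same pattern outside kag (the retriever class and delays are made up):

```python
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial

class FakeRetriever:
    """Hypothetical retriever, only here to illustrate the timing wrapper."""
    def __init__(self, name, delay):
        self.name, self.delay = name, delay
    def invoke(self, task, **kwargs):
        time.sleep(self.delay)
        return f"{self.name} answered {task!r}"

def wrapped_invoke(retriever, task, **kwargs):
    start = time.time()                       # measured inside the worker thread
    output = retriever.invoke(task, **kwargs)
    return output, time.time() - start        # per-retriever elapsed time

retrievers = [FakeRetriever("vector", 0.2), FakeRetriever("graph", 0.5)]
with ThreadPoolExecutor() as pool:
    futures = [(pool.submit(partial(wrapped_invoke, r, "q1")), r) for r in retrievers]
    for future, r in futures:
        output, elapsed = future.result()
        print(f"{r.name}: {output} in {elapsed:.2f}s")
```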
diff --git a/kag/solver/main_solver.py b/kag/solver/main_solver.py
index f2fb4b16..32120a3a 100644
--- a/kag/solver/main_solver.py
+++ b/kag/solver/main_solver.py
@@ -140,8 +140,6 @@ def get_pipeline_conf(use_pipeline_name, config):
raise RuntimeError("mcpServers not found in config.")
default_solver_pipeline["executors"] = mcp_executors
- # update KAG_CONFIG
- KAG_CONFIG.update_conf(default_pipeline_conf)
return default_solver_pipeline
@@ -167,8 +165,11 @@ async def do_qa_pipeline(
f"Knowledge base with id {kb_project_id} not found in qa_config['kb']"
)
continue
-
- for index_name in matched_kb.get("index_list", []):
+ index_list = matched_kb.get("index_list", [])
+ if use_pipeline in ["default_pipeline"]:
+ # we only use chunk index
+ index_list = ["chunk_index"]
+ for index_name in index_list:
index_manager = KAGIndexManager.from_config(
{
"type": index_name,
@@ -339,7 +340,7 @@ class SolverMain:
def invoke(
self,
project_id: int,
- task_id: int,
+ task_id,
query: str,
session_id: str = "0",
is_report=True,
diff --git a/kag/solver/pipelineconf/naive_rag.yaml b/kag/solver/pipelineconf/naive_rag.yaml
index 2cc249e4..84d4b8a6 100644
--- a/kag/solver/pipelineconf/naive_rag.yaml
+++ b/kag/solver/pipelineconf/naive_rag.yaml
@@ -3,20 +3,17 @@ pipeline_name: default_pipeline
#------------kag-solver configuration start----------------#
-
-chunk_retrieved_executor: &chunk_retrieved_executor_conf
- type: chunk_retrieved_executor
- top_k: 10
- retriever:
- type: vector_chunk_retriever
- score_threshold: 0.65
- vectorize_model: "{vectorize_model}"
-
+kag_retriever_executor: &kag_retriever_executor_conf
+ type: kag_hybrid_retrieval_executor
+ retrievers: "{retrievers}"
+ merger:
+ type: kag_merger
+ enable_summary: false
solver_pipeline:
type: naive_rag_pipeline
executors:
- - *chunk_retrieved_executor_conf
+ - *kag_retriever_executor_conf
generator:
type: llm_index_generator
llm_client: "{chat_llm}"
diff --git a/kag/solver/planner/kag_model_planner.py b/kag/solver/planner/kag_model_planner.py
index 62bcd6c2..008ac3c1 100644
--- a/kag/solver/planner/kag_model_planner.py
+++ b/kag/solver/planner/kag_model_planner.py
@@ -186,6 +186,7 @@ class KAGModelPlanner(PlannerABC):
.replace("", "")
.strip()
)
+ context.kwargs["planner_thought"] = logic_form_response
sub_queries, logic_forms = parse_logic_form_with_str(logic_form_str)
logic_forms = self.logic_node_parser.parse_logic_form_set(
diff --git a/kag/solver/prompt/expression_builder.py b/kag/solver/prompt/expression_builder.py
index 4002bb48..36c25dbd 100644
--- a/kag/solver/prompt/expression_builder.py
+++ b/kag/solver/prompt/expression_builder.py
@@ -11,42 +11,49 @@ logger = logging.getLogger(__name__)
class ExpressionBuildr(PromptABC):
template_zh = (
f"今天是{get_now(language='zh')}。"
- + """\n# instruction
+ + """
+# instruction
根据给出的问题和数据,编写python代码,输出问题结果。
-为了便于理解,输出从context中提取的数据,输出中间计算过程和结果。
-注意严格根据输入内容进行编写代码,不允许进行假设
-例如伤残等级如果context中未提及,则认为没有被认定为残疾
-如果无法回答问题,直接返回:I don't know.
+从context中提取的数据必须显式赋值,所有计算步骤必须用代码实现,不得隐含推断。
+必须输出中间计算过程和结果,格式为print语句。
+如果context未提供必要数据或无法计算,直接打印"I don't know."
# output format
-直接输出python代码,python版本为3.10,不要包含任何其他信息
+严格输出以下结构的python代码(版本3.10):
+1. 数据提取部分:代码中涉及输入的数值需要从context及question中提取,不允许进行假设
+2. 计算过程:分步实现所有数学运算,每个步骤对应独立变量
+3. 输出:每个中间变量和最终结果必须用print语句输出
# examples
## 例子1
### input
#### question
-47000元按照万分之1.5一共612天,计算利息,一共多少钱?
+4百万元按照日利率万分之1.5,一共612天,计算利息,一共多少钱?
+#### context
+日利率万分之1.5
### output
```python
-# 初始本金
-principal = 47000
+# 初始本金(单位:百万)
+principal = 4 # 单位:百万
-# 利率(万分之1.5)
-rate = 1.5 / 10000
+# 日利率计算(万分之1.5)
+daily_rate = 1.5 / 10000
-# 天数
+# 计算周期
days = 612
-# 计算年利率
-annual_rate = rate * 365
+# 单日利息计算
+daily_interest = principal * daily_rate
-# 计算利息
-interest = principal * (annual_rate / 365) * days
+# 累计利息计算
+total_interest = daily_interest * days
-# 输出总金额(本金+利息)
-total_amount = principal + interest
+# 总金额计算
+total_amount = principal + total_interest
-print(f"总金额:{total_amount:.2f}元")
+print(f"单日利息:{daily_interest:.2f}百万")
+print(f"累计利息:{total_interest:.2f}百万")
+print(f"总金额:{total_amount:.2f}百万")
```
## 例子2
@@ -70,13 +77,26 @@ revenue_2020 = revenue_2019 * (1 + growth_rate)
print(f"2020年的预计收入为: {revenue_2020:.2f}万")
```
+## 例子3
+### input
+#### question
+47000元按照612天计算利息,本息一共多少钱?
+#### context
+
+### output
+```python
+# 未给出利率,无法计算
+print("未给出利率,无法计算")
+```
# input
## question
$question
## context
$context
## error
- $error"""
+$error
+## output
+"""
)
template_en = (
f"Today is {get_now(language='en')}。\n"
diff --git a/kag/solver/reporter/open_spg_kag_model_reporter.py b/kag/solver/reporter/open_spg_kag_model_reporter.py
index e5ceae7d..d569a16b 100644
--- a/kag/solver/reporter/open_spg_kag_model_reporter.py
+++ b/kag/solver/reporter/open_spg_kag_model_reporter.py
@@ -1,5 +1,6 @@
import logging
import re
+import time
from kag.common.conf import KAG_PROJECT_CONF
from kag.common.parser.logic_node_parser import extract_steps_and_actions
@@ -72,14 +73,16 @@ def process_tag_template(text):
}
clean_text = ""
for tag_info in all_tags:
+ content = tag_info[1]
if tag_info[0] in xml_tag_template:
- content = tag_info[1]
if "search" == tag_info[0]:
content = process_planning(content)
clean_text += xml_tag_template[tag_info[0]][
KAG_PROJECT_CONF.language
].format_map(SafeDict({"content": content}))
- return remove_xml_tags(clean_text)
+ else:
+ clean_text += content
+ text = remove_xml_tags(clean_text)
return text
diff --git a/kag/solver/reporter/open_spg_reporter.py b/kag/solver/reporter/open_spg_reporter.py
index ef2ae01a..8d0d9c95 100644
--- a/kag/solver/reporter/open_spg_reporter.py
+++ b/kag/solver/reporter/open_spg_reporter.py
@@ -159,12 +159,12 @@ def render_jinja2_template(template_str, context):
"""
try:
template = Template(template_str, undefined=SilentUndefined)
- return template.render(**context).strip()
+ return template.render(**context)
except Exception as e:
logging.error(
f"Jinja2 rendering failed: {e}, Original template: {template_str}"
)
- return template_str.strip() # Fallback to raw template string on failure
+ return template_str # Fallback to raw template string on failure
@ReporterABC.register("open_spg_reporter")
@@ -264,12 +264,12 @@ Rerank the documents and take the top {{ chunk_num }}.
}
self.tag_mapping = {
"Graph Show": {
- "en": "{content}",
- "zh": "{content}",
+ "en": "{{ content }}",
+ "zh": "{{ content }}",
},
"Rewrite query": {
- "en": "Rethinking question using LLM: {content}",
- "zh": "根据依赖问题重写子问题: {content}",
+ "en": "Rethinking question using LLM: {{ content }}",
+ "zh": "根据依赖问题重写子问题: {{ content }}",
},
"language_setting": {
"en": "",
@@ -277,125 +277,153 @@ Rerank the documents and take the top {{ chunk_num }}.
},
"Iterative planning": {
"en": """
-
+
-{content}
+{{ content }}
-""",
+{% if status == 'success' %}
+
+{% endif %}""",
"zh": """
-
+
-{content}
+{{ content }}
-""",
+{% if status == 'success' %}
+
+{% endif %}""",
},
"Static planning": {
"en": """
-
+
-{content}
+{{ content }}
-""",
+{% if status == 'success' %}
+
+{% endif %}""",
"zh": """
-
+
-{content}
+{{ content }}
-""",
+{% if status == 'success' %}
+
+{% endif %}""",
},
"begin_sub_kag_retriever": {
- "en": "Starting {component_name}: {content} {desc}",
- "zh": "执行{component_name}: {content} {desc}",
+ "en": "Starting {{component_name}}: {{content}} {{desc}}",
+ "zh": "执行{{component_name}}: {{content}} {{desc}}",
},
"end_sub_kag_retriever": {
- "en": " {content}",
- "zh": " {content}",
+ "en": " {{ content }}",
+ "zh": " {{ content }}",
},
"rc_retriever_rewrite": {
"en": """
-
+
-Rewritten question:\n{content}
+Rewritten question:
+{{ content }}
-""",
+{% if status == 'success' %}
+
+{% endif %}""",
"zh": """
-
+
-重写问题为:\n\n{content}
+重写问题为:
+{{ content }}
-""",
+{% if status == 'success' %}
+
+{% endif %}""",
},
"rc_retriever_summary": {
- "en": "Summarizing retrieved documents,{content}",
- "zh": "对文档进行总结,{content}",
+ "en": "Summarizing retrieved documents,{{ content }}",
+ "zh": "对文档进行总结,{{ content }}",
},
"kg_retriever_summary": {
- "en": "Summarizing retrieved graph,{content}",
- "zh": "对召回的知识进行总结,{content}",
+ "en": "Summarizing retrieved graph,{{ content }}",
+ "zh": "对召回的知识进行总结,{{ content }}",
},
"retriever_summary": {
- "en": "Summarizing retrieved documents,{content}",
- "zh": "对文档进行总结,{content}",
+ "en": "Summarizing retrieved documents,{{ content }}",
+ "zh": "对文档进行总结,{{ content }}",
},
"begin_summary": {
- "en": "Summarizing retrieved information, {content}",
- "zh": "对检索的信息进行总结, {content}",
+ "en": "Summarizing retrieved information, {{ content }}",
+ "zh": "对检索的信息进行总结, {{ content }}",
},
"begin_task": {
"en": """
-
+
-{content}
+{{ content }}
-""",
+{% if status == 'success' %}
+
+{% endif %}""",
"zh": """
-
+
-{content}
+{{ content }}
-""",
+{% if status == 'success' %}
+
+{% endif %}""",
},
"logic_node": {
"en": """Translate query to logic form expression
```json
-{content}
+{{ content }}
```""",
"zh": """将query转换成逻辑形式表达
```json
-{content}
+{{ content }}
```""",
},
"kag_retriever_result": {
- "en": "Retrieved documents\n\n{content}",
- "zh": "检索到的文档\n\n{content}",
+ "en": """Retrieved documents
+{{ content }}""",
+ "zh": """检索到的文档
+{{ content }}""",
},
"failed_kag_retriever": {
"en": """KAG retriever failed
```json
-{content}
+{{ content }}
```
""",
"zh": """KAG检索失败
```json
-{content}
+{{ content }}
```
""",
},
"end_math_executor": {
- "en": "Math executor completed\n\n{content}",
- "zh": "计算结束\n\n{content}",
+ "en": """Math executor completed
+{{ content }}""",
+ "zh": """计算结束
+{{ content }}""",
},
"code_generator": {
- "en": "Generating code\n \n{content}\n",
- "zh": "正在生成代码\n \n{content}\n",
+ "en": """Generating code
+{{ content }}
+
+""",
+ "zh": """正在生成代码
+{{ content }}
+
+""",
},
}
task_id = kwargs.get(KAGConstants.KAG_QA_TASK_CONFIG_KEY, None)
@@ -425,7 +453,11 @@ Rewritten question:\n{content}
if tpl:
format_params = {"content": datas}
format_params.update(content_params)
- datas = tpl.format_map(SafeDict(format_params))
+ if "{" in tpl or "%}" in tpl:
+ rendered = render_jinja2_template(tpl, format_params)
+ else:
+ rendered = tpl.format_map(SafeDict(format_params))
+ datas = rendered
elif str(datas).strip() != "":
output = str(datas).strip()
if output != "":
@@ -516,7 +548,7 @@ Rewritten question:\n{content}
if self.last_report.to_dict() == request.to_dict():
return
logger.info(
- f"do_report: {content.answer} think={content.think} status={status_enum} ret={ret}"
+ f"do_report: think={content.think} {content.answer} status={status_enum} ret={ret}"
)
self.last_report = request
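The dispatch added at the end of the report path above keeps legacy `{content}`-style templates on `str.format_map` while the rewritten templates go through Jinja2. A stripped-down sketch of that decision (the patch additionally uses `SilentUndefined` and falls back to the raw template on rendering errors; both omitted here):

```python
from jinja2 import Template

def render(tpl: str, params: dict) -> str:
    # Same heuristic as the patch: anything containing "{" (covers "{{ ... }}")
    # or "%}" is treated as a Jinja2 template; brace-free strings fall back to
    # format_map, which for them is effectively a passthrough.
    if "{" in tpl or "%}" in tpl:
        return Template(tpl).render(**params)
    return tpl.format_map(params)

print(render("Starting {{component_name}}: {{content}} {{desc}}",
             {"component_name": "kg_cs", "content": "sub-question 1", "desc": ""}))
print(render("{{ content }}{% if status == 'success' %} [done]{% endif %}",
             {"content": "planning finished", "status": "success"}))
```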
diff --git a/kag/solver/server/__init__.py b/kag/solver/server/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/kag/solver/server/asyn_task_manager.py b/kag/solver/server/asyn_task_manager.py
new file mode 100644
index 00000000..069c4e89
--- /dev/null
+++ b/kag/solver/server/asyn_task_manager.py
@@ -0,0 +1,168 @@
+import concurrent.futures
+import queue
+import threading
+import time
+import uuid
+import logging
+from cachetools import TTLCache
+
+logger = logging.getLogger()
+
+
+class AsyncTaskManager:
+ def __init__(self, max_workers=10, ttl=3600):
+ """
+ Initialize async task manager
+
+ Args:
+ max_workers (int): Maximum number of worker threads
+ ttl (int): Time-to-live for task results in seconds
+ """
+ self.max_workers = max_workers
+ self.task_queue = queue.Queue()
+ self.result_cache = TTLCache(maxsize=1000, ttl=ttl)
+ self.result_cache_lock = threading.Lock() # Protect cache from race conditions
+ self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
+ self.workers = [
+ threading.Thread(target=self.worker, daemon=True)
+ for _ in range(max_workers)
+ ]
+ for w in self.workers:
+ w.start()
+
+ def worker(self):
+ """Worker thread main loop that processes tasks"""
+ while True:
+ try:
+ # Get the next task from the queue (blocks until one is available)
+ task = self.task_queue.get()
+ task_id, func, args, kwargs = task
+ logger.info(f"Processing task {task_id}")
+ # finish flag
+ if task_id is None:
+ self.task_queue.task_done()
+ break
+
+ # Update cache with running status
+ with self.result_cache_lock:
+ self.result_cache[task_id] = {
+ "task_id": task_id,
+ "status": "running",
+ "result": None,
+ }
+
+ # Execute task
+ future = self.executor.submit(func, *args, **kwargs)
+ result = future.result()
+ status = "completed"
+
+ except queue.Empty:
+ # Only reachable if a timeout is ever added to the get() call above
+ continue
+
+ except Exception as e:
+ # Handle task execution errors
+ result = str(e)
+ status = "failed"
+ logger.error(f"Task {task_id} failed with error: {e}", exc_info=True)
+
+ # Store final result in cache
+ try:
+ with self.result_cache_lock:
+ self.result_cache[task_id] = {
+ "task_id": task_id,
+ "status": status,
+ "result": result,
+ }
+ logger.info(f"Task {task_id} completed with status: {status}")
+ finally:
+ # Always mark task as done
+ self.task_queue.task_done()
+
+ def submit_task(self, func, *args, **kwargs):
+ """
+ Submit a new task to the queue
+
+ Args:
+ func: Callable function to execute
+ *args: Positional arguments for the function
+ **kwargs: Keyword arguments for the function
+
+ Returns:
+ str: Unique task ID
+ """
+ task_id = str(uuid.uuid4())
+ self.task_queue.put((task_id, func, args, kwargs))
+ return task_id
+
+ def get_task_result(self, task_id):
+ """
+ Get result for a specific task
+
+ Args:
+ task_id (str): Unique task identifier
+
+ Returns:
+ dict: Task result information or expired status
+ """
+ with self.result_cache_lock:
+ return self.result_cache.get(
+ task_id,
+ {
+ "task_id": task_id,
+ "status": "failed",
+ "result": "Result not found or expired",
+ },
+ )
+
+ def shutdown(self):
+ """Gracefully shutdown all worker threads and executors"""
+ # Send shutdown signals
+ for _ in range(self.max_workers):
+ self.task_queue.put((None, None, (), {}))
+
+ # Wait for queue to empty and workers to terminate
+ self.task_queue.join()
+
+ # Shutdown executors
+ self.executor.shutdown(wait=True)
+ for worker in self.workers:
+ worker.join(timeout=5)
+
+
+# Global async task manager instance
+asyn_task = AsyncTaskManager()
+
+if __name__ == "__main__":
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+
+ # Create task manager instance
+ task_manager = AsyncTaskManager(max_workers=5, ttl=600)
+
+ # Example task function
+ def example_task(x, y):
+ time.sleep(1) # Simulate work
+ return x
+
+ # Submit test tasks
+ task_ids = [task_manager.submit_task(example_task, i, i + 1) for i in range(6)]
+
+ # Monitor task progress
+ try:
+ while True:
+ time.sleep(1)
+ if all(
+ "completed" in task_manager.get_task_result(tid)["status"]
+ for tid in task_ids
+ ):
+ break
+ except KeyboardInterrupt:
+ logger.info("Shutting down due to user interrupt")
+
+ # Print results
+ for task_id in task_ids:
+ print(f"Task {task_id} result: {task_manager.get_task_result(task_id)}")
+
+ # Clean up resources
+ task_manager.shutdown()
diff --git a/kag/solver/server/example/__init__.py b/kag/solver/server/example/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/kag/solver/server/main_server.py b/kag/solver/server/main_server.py
new file mode 100644
index 00000000..755e3590
--- /dev/null
+++ b/kag/solver/server/main_server.py
@@ -0,0 +1,63 @@
+from fastapi import FastAPI
+import uvicorn
+
+from kag.solver.main_solver import SolverMain
+from kag.solver.server.asyn_task_manager import AsyncTaskManager
+from kag.solver.server.model.task_req import FeatureRequest, TaskReq
+
+
+def run_main_solver(task: TaskReq):
+ return SolverMain().invoke(
+ project_id=task.project_id,
+ task_id=task.req_id,
+ query=task.req.query,
+ is_report=task.req.report,
+ host_addr=task.req.host_addr,
+ app_id=task.app_id,
+ params=task.config,
+ )
+
+
+class KAGSolverServer:
+ def __init__(self, service_name: str):
+ """
+ Initialize a FastAPI service instance
+
+ Args:
+ service_name (str): Service name, determines which routing logic to load
+ """
+ self.service_name = service_name
+ self.app = FastAPI(title=f"{service_name} API")
+
+ # Bind routes according to service name
+ self._setup_routes()
+ self.async_manager = AsyncTaskManager()
+
+ def sync_task(self, task: TaskReq):
+ if task.cmd == "submit":
+
+ return {
+ "success": True,
+ "status": "init",
+ "result": self.async_manager.submit_task(run_main_solver, task),
+ }
+ elif task.cmd == "query":
+ return self.async_manager.get_task_result(task_id=task.req_id)
+ else:
+ return {
+ "success": False,
+ "status": "failed",
+ "result": f"invalid input cmd {task.cmd}",
+ }
+
+ def _setup_routes(self):
+ """Dynamically bind routes according to service name"""
+
+ @self.app.post("/process")
+ def process(req: FeatureRequest):
+ return self.sync_task(task=req.features.task_req)
+
+ def run(self, host="0.0.0.0", port=8000):
+ """Start the service"""
+ print(f"Starting {self.service_name} service on {host}:{port}")
+ uvicorn.run(self.app, host=host, port=port)
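A minimal client-side sketch of what the `/process` route expects (host, port and the sample values are assumptions; the nested `in_string` / `req` encoding mirrors the `task_req.py` models below, and `requests` is assumed to be available):

```python
import json
import requests  # assumed available; not listed in the new requirement.txt

req_body = {"query": "What is KAG?", "report": False, "host_addr": ""}
task_req = {
    "req_id": "demo-001",
    "cmd": "submit",             # use "query" to poll a previously submitted req_id
    "mode": "async",
    "project_id": 1,
    "req": json.dumps(req_body),   # ReqBody travels as a JSON string
    "config": "{}",
}
payload = {"features": {"in_string": json.dumps(task_req)}}

resp = requests.post("http://127.0.0.1:8000/process", json=payload, timeout=30)
print(resp.json())  # e.g. {"success": true, "status": "init", "result": "<task id>"}
```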
diff --git a/kag/solver/server/model/__init__.py b/kag/solver/server/model/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/kag/solver/server/model/task_req.py b/kag/solver/server/model/task_req.py
new file mode 100644
index 00000000..0a48bd3b
--- /dev/null
+++ b/kag/solver/server/model/task_req.py
@@ -0,0 +1,122 @@
+import json
+from typing import Optional
+
+from pydantic import BaseModel, model_validator, field_serializer
+
+
+class ReqBody(BaseModel):
+ """Request body model containing query parameters"""
+
+ query: str = ""
+ report: bool = True
+ host_addr: str = ""
+
+
+class TaskReq(BaseModel):
+ """Task request model with validation logic"""
+
+ app_id: str = ""
+ project_id: int = 0
+ req_id: str = ""
+ cmd: str = ""
+ mode: str = ""
+ req: str = None
+ config: str = "{}"
+
+ @model_validator(mode="after")
+ def parse_req_to_req_body(self):
+ """Parse req string to ReqBody object and process config field"""
+ try:
+ import json
+
+ if isinstance(self.req, str):
+ req_body_dict = json.loads(self.req)
+ self.req = ReqBody(**req_body_dict)
+ if isinstance(self.config, str) and self.config:
+ config_dict = json.loads(self.config)
+ self.config = config_dict
+ except Exception as e:
+ raise ValueError(f"Failed to parse 'req' field to ReqBody: {e}")
+ return self
+
+ @field_serializer("req")
+ def serialize_req(self, value: object) -> object:
+ """Serialize ReqBody back to JSON string"""
+ if isinstance(value, ReqBody):
+ return value.model_dump_json()
+ return value # Return as-is if already a string
+
+
+# Request model with TaskReq parsing capability
+class Request(BaseModel):
+ """Container model for task request data"""
+
+ in_string: str
+ task_req: Optional[TaskReq] = None
+
+ @model_validator(mode="after")
+ def parse_in_string_to_task_req(self):
+ """Convert in_string JSON string to TaskReq object"""
+ try:
+ import json
+
+ task_req_dict = json.loads(self.in_string)
+ self.task_req = TaskReq(**task_req_dict)
+ except Exception as e:
+ raise ValueError(f"Invalid TaskReq JSON string: {e}")
+ return self
+
+
+class FeatureRequest(BaseModel):
+ """Top-level request wrapper with features container"""
+
+ features: Request
+
+
+if __name__ == "__main__":
+
+ def feature_request_parsing():
+ """Demonstrate nested model parsing workflow"""
+ # Build innermost ReqBody JSON string
+ req_body = ReqBody(
+ query="阿里巴巴财报中,2024年-截至9月30日止六个月的收入是多少?其中云智能集团收入是多少?占比是多少",
+ report=True,
+ host_addr="https://spg.alipay.com",
+ )
+ req_body_json = json.dumps(req_body.model_dump())
+
+ # Build TaskReq dictionary and serialize to string
+ task_req = TaskReq(
+ req_id="9400110",
+ cmd="submit",
+ mode="async",
+ req=req_body_json,
+ app_id="app_id",
+ project_id=4200050,
+ config={"timeout": 10},
+ )
+ task_req_json = json.dumps(task_req.model_dump())
+
+ # Construct final FeatureRequest JSON string
+ input_data = {"features": {"in_string": task_req_json}}
+
+ # Deserialize to FeatureRequest model
+ feature_request = FeatureRequest(**input_data)
+
+ # Validate in_string parsed to TaskReq
+ assert isinstance(feature_request.features.task_req, TaskReq)
+ assert feature_request.features.task_req.req_id == "9400110"
+ assert feature_request.features.task_req.cmd == "submit"
+ assert feature_request.features.task_req.mode == "async"
+ assert feature_request.features.task_req.config == {"timeout": 10}
+
+ # Validate TaskReq.req parsed to ReqBody
+ req_body_parsed = feature_request.features.task_req.req
+ assert isinstance(req_body_parsed, ReqBody)
+ assert req_body_parsed.query == req_body.query
+ assert req_body_parsed.report is True
+ assert req_body_parsed.host_addr == "https://spg.alipay.com"
+
+ print("✅ All assertions passed!")
+
+ feature_request_parsing()
diff --git a/kag/solver/server/requirement.txt b/kag/solver/server/requirement.txt
new file mode 100644
index 00000000..170703df
--- /dev/null
+++ b/kag/solver/server/requirement.txt
@@ -0,0 +1 @@
+fastapi
\ No newline at end of file
diff --git a/knext/common/env.py b/knext/common/env.py
index 6a9994c7..595c8d42 100644
--- a/knext/common/env.py
+++ b/knext/common/env.py
@@ -48,13 +48,16 @@ class Environment:
@property
def config(self):
- closest_config = self._closest_config()
- if not hasattr(self, "_config_path") or self._config_path != closest_config:
- self._config_path = closest_config
- self._config = self.get_config()
+ try:
+ closest_config = self._closest_config()
+ if not hasattr(self, "_config_path") or self._config_path != closest_config:
+ self._config_path = closest_config
+ self._config = self.get_config()
- if self._config is None:
- self._config = self.get_config()
+ if self._config is None:
+ self._config = self.get_config()
+ except:
+ return {}
return self._config
diff --git a/requirements.txt b/requirements.txt
index 2085ea8d..44ed0d4b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,7 +22,7 @@ numpy>=1.23.1
pypdf
pandas
pycryptodome
-markdown
+markdown==3.7
bs4
protobuf==3.20.1
neo4j
diff --git a/tests/unit/common/kag_utils.py b/tests/unit/common/kag_utils.py
new file mode 100644
index 00000000..5d18d853
--- /dev/null
+++ b/tests/unit/common/kag_utils.py
@@ -0,0 +1,64 @@
+from kag.common.utils import extract_tag_content
+
+
+def run_extra_tag():
+ test_cases = [
+ {
+ "input": "abcedsome wordother tags",
+ "expected": [("tag1", "abced"), ("", "some word"), ("tag2", "other tags")],
+ "description": "基本闭合标签与无标签文本混合",
+ },
+ {
+ "input": "
Hello world this is test",
+ "expected": [
+ ("p", "Hello "),
+ ("b", "world"),
+ ("", " this is "),
+ ("i", "test"),
+ ],
+ "description": "混合闭合与未闭合标签",
+ },
+ {
+ "input": "plain text without any tags",
+ "expected": [("", "plain text without any tags")],
+ "description": "纯文本无标签",
+ },
+ {
+ "input": "
\n Line 1\n Line 2\n Line 3\n
",
+ "expected": [
+ ("div", "\n Line 1\n Line 2\n Line 3\n")
+ ],
+ "description": "多行内容和空白处理",
+ },
+ {
+ "input": "ABC",
+ "expected": [("a", "A"), ("b", "B"), ("c", "C")],
+ "description": "连续多个闭合标签",
+ },
+ {
+ "input": "My DocumentThis is the content",
+ "expected": [("title", "My Document"), ("content", "This is the content")],
+ "description": "未闭合标签(EOF结尾)",
+ },
+ {
+ "input": "Error: &*^%$#@!;End of log",
+ "expected": [("log", "Error: &*^%$#@!;"), ("note", "End of log")],
+ "description": "含特殊字符的内容",
+ },
+ {
+ "input": "",
+ "expected": [],
+ "description": "空字符串输入",
+ },
+ ]
+
+ for i, test in enumerate(test_cases):
+ result = extract_tag_content(test["input"])
+ assert (
+ result == test["expected"]
+ ), f"Test {i+1} failed: {test['description']}\nGot: {result}\nExpected: {test['expected']}"
+ print(f"Test {i+1} passed: {test['description']}")
+
+
+if __name__ == "__main__":
+ run_extra_tag()