update recall for AFAC2024

This commit is contained in:
锦呈 2025-06-06 16:26:14 +08:00
parent 74b91abe8c
commit 6c5d8a5192
10 changed files with 315 additions and 38 deletions

View File

@ -20,7 +20,8 @@ from kag.interface import ExtractorABC, PromptABC
from kag.builder.model.chunk import Chunk
from kag.builder.model.sub_graph import SubGraph
from knext.schema.client import CHUNK_TYPE
from kag.interface.common.model.chunk import ChunkTypeEnum
from knext.schema.client import CHUNK_TYPE, TABLE_TYPE
from knext.common.base.runnable import Input, Output
logger = logging.getLogger(__name__)
@ -80,6 +81,10 @@ class AtomicQueryExtractor(ExtractorABC):
Returns:
List[Output]: A list of processed results, containing subgraph information.
"""
if input.type == ChunkTypeEnum.Text:
o_label = CHUNK_TYPE
else:
o_label = TABLE_TYPE
title = input.name
passage = title + "\n" + input.content
try:
@ -105,7 +110,7 @@ class AtomicQueryExtractor(ExtractorABC):
s_label="AtomicQuery",
p="sourceChunk",
o_id=f"{chunk.id}",
o_label=CHUNK_TYPE,
o_label=o_label,
)
subgraph.id = chunk.id
return [subgraph]

View File

@ -5,7 +5,8 @@ from kag.interface import ExtractorABC
from kag.builder.model.chunk import Chunk
from kag.builder.model.sub_graph import SubGraph
from knext.schema.client import CHUNK_TYPE
from kag.interface.common.model.chunk import ChunkTypeEnum
from knext.schema.client import CHUNK_TYPE, TABLE_TYPE
from knext.common.base.runnable import Input, Output
logger = logging.getLogger(__name__)
@ -43,6 +44,10 @@ class OutlineExtractor(ExtractorABC):
Returns:
List[Output]: A list of processed results, containing subgraph information.
"""
if input.type == ChunkTypeEnum.Text:
o_label = CHUNK_TYPE
else:
o_label = TABLE_TYPE
if "/" not in input.name:
outline_name = input.name
else:
@ -79,7 +84,7 @@ class OutlineExtractor(ExtractorABC):
s_label="Outline",
p="sourceChunk",
o_id=f"{input.id}",
o_label=CHUNK_TYPE,
o_label=o_label,
properties={},
)

View File

@ -7,8 +7,9 @@ from kag.builder.prompt.utils import init_prompt_with_fallback
from kag.common.conf import KAG_PROJECT_CONF
from kag.common.config import get_default_chat_llm_config
from kag.interface import ExtractorABC, LLMClient, PromptABC
from kag.interface.common.model.chunk import ChunkTypeEnum
from knext.common.base.runnable import Input, Output
from knext.schema.client import CHUNK_TYPE
from knext.schema.client import CHUNK_TYPE, TABLE_TYPE
logger = logging.getLogger(__name__)
@ -53,6 +54,11 @@ class SummaryExtractor(ExtractorABC):
Returns:
List[Output]: A list of processed results, containing subgraph information.
"""
if input.type == ChunkTypeEnum.Text:
o_label = CHUNK_TYPE
else:
o_label = TABLE_TYPE
if "/" not in input.name:
summary_name = input.name
else:
@ -98,7 +104,7 @@ class SummaryExtractor(ExtractorABC):
s_label="Summary",
p="sourceChunk",
o_id=f"{input.id}",
o_label=CHUNK_TYPE,
o_label=o_label,
properties={},
)

View File

@ -1,6 +1,7 @@
from typing import List
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import re
from .LLMJudger import LLMJudger
from .evaUtils import get_em_f1
@ -70,18 +71,182 @@ class Evaluate:
ret.append(content)
return ret
def recall_top(self, predictionlist: list, goldlist: List[str]):
def _extract_text_only(self, text, mode="all"):
"""
Extract text content only, removing all symbols including newlines
Parameters:
text: input text
mode: extraction mode ('all', 'chinese_only')
- 'all': extract Chinese characters, English letters and numbers
- 'chinese_only': extract only Chinese characters
Returns:
str: extracted text content (lowercase for English)
"""
if mode == "chinese_only":
# Use regex to keep only Chinese characters, remove all other content
text_only = re.sub(r"[^\u4e00-\u9fa5]", "", str(text))
return text_only
else:
# Use regex to keep only Chinese characters, English letters and numbers, remove all symbols (including newlines, spaces, etc.)
text_only = re.sub(r"[^\u4e00-\u9fa5a-zA-Z0-9]", "", str(text))
# Convert to lowercase
return text_only.lower()
def _tokenize_chinese(self, text):
"""
Tokenize Chinese text into word set
Parameters:
text: Chinese text
Returns:
set: set of Chinese words/characters
"""
try:
# Try to use jieba for word segmentation
import jieba
words = jieba.lcut(text)
# Filter out single characters and short words, keep meaningful words
meaningful_words = [word.strip() for word in words if len(word.strip()) > 1]
# If no meaningful words found, fall back to character-level
if not meaningful_words:
meaningful_words = [char for char in text if char.strip()]
return set(meaningful_words)
except ImportError:
# Fallback to character-level tokenization if jieba not available
return set([char for char in text if char.strip()])
def _calculate_set_similarity(self, set1, set2):
"""
Calculate similarity between two sets using Jaccard similarity
Parameters:
set1: first set
set2: second set
Returns:
float: Jaccard similarity (0.0 to 1.0)
"""
if not set1 and not set2:
return 1.0
if not set1 or not set2:
return 0.0
intersection = len(set1.intersection(set2))
union = len(set1.union(set2))
return intersection / union if union > 0 else 0.0
def fuzzy_intersection(
self, gold_set, prediction_set, fuzzy_mode="all", similarity_threshold=0.9
):
"""
Calculate intersection using fuzzy matching, only matching text content and ignoring symbols
Parameters:
gold_set: ground truth set
prediction_set: prediction result set
fuzzy_mode: text extraction mode ('all', 'chinese_only')
similarity_threshold: minimum similarity threshold for chinese_only mode (0.0 to 1.0)
Returns:
int: number of matches
"""
matches = 0
for gold_item in gold_set:
gold_text = self._extract_text_only(gold_item, fuzzy_mode)
for pred_item in prediction_set:
pred_text = self._extract_text_only(pred_item, fuzzy_mode)
if fuzzy_mode == "chinese_only":
# Use tokenization and set similarity for Chinese text
gold_tokens = self._tokenize_chinese(gold_text)
pred_tokens = self._tokenize_chinese(pred_text)
similarity = self._calculate_set_similarity(
gold_tokens, pred_tokens
)
if similarity >= similarity_threshold:
matches += 1
break # Break inner loop after finding match to avoid duplicate counting
else:
# Use original substring matching for 'all' mode
if gold_text in pred_text or pred_text in gold_text:
matches += 1
break # Break inner loop after finding match to avoid duplicate counting
return matches
def fuzzy_difference(
self, gold_set, prediction_set, fuzzy_mode="all", similarity_threshold=0.9
):
"""
Calculate difference using fuzzy matching, only matching text content and ignoring symbols
Returns the number of items in gold_set but not in prediction_set
Parameters:
gold_set: ground truth set
prediction_set: prediction result set
fuzzy_mode: text extraction mode ('all', 'chinese_only')
similarity_threshold: minimum similarity threshold for chinese_only mode (0.0 to 1.0)
Returns:
int: number of items in difference set
"""
unmatched = 0
for gold_item in gold_set:
gold_text = self._extract_text_only(gold_item, fuzzy_mode)
found_match = False
for pred_item in prediction_set:
pred_text = self._extract_text_only(pred_item, fuzzy_mode)
if fuzzy_mode == "chinese_only":
# Use tokenization and set similarity for Chinese text
gold_tokens = self._tokenize_chinese(gold_text)
pred_tokens = self._tokenize_chinese(pred_text)
similarity = self._calculate_set_similarity(
gold_tokens, pred_tokens
)
if similarity >= similarity_threshold:
found_match = True
break
else:
# Use original substring matching for 'all' mode
if gold_text in pred_text or pred_text in gold_text:
found_match = True
break
if not found_match:
unmatched += 1
return unmatched
def recall_top(
self,
predictionlist: list,
goldlist: List[str],
is_chunk_data: bool = True,
fuzzy_mode: str = "all",
similarity_threshold: float = 0.9,
):
"""
Calculate recall for top-3, top-5, and all predictions.
Parameters:
predictionlist (List[str]): List of predicted values from the model.
goldlist (List[str]): List of actual ground truth values.
is_chunk_data (bool): Whether predictionlist contains chunk data.
fuzzy_mode (str): Text extraction mode ('all', 'chinese_only').
similarity_threshold (float): Minimum similarity threshold for chinese_only mode (0.0 to 1.0).
Returns:
dict: Dictionary containing recall for top-3, top-5, and all predictions.
"""
predictionlist = self.convert_chunk_data_2_str(predictionlist)
if is_chunk_data:
predictionlist = self.convert_chunk_data_2_str(predictionlist)
# Split predictions into lists of top-3 and top-5
top3_predictions = predictionlist[:3]
top5_predictions = predictionlist[:5]
@ -91,8 +256,13 @@ class Evaluate:
top3_set = set(top3_predictions)
top5_set = set(top5_predictions)
true_positives_top3 = len(gold_set.intersection(top3_set))
false_negatives_top3 = len(gold_set - top3_set)
# Use fuzzy matching instead of exact intersection
true_positives_top3 = self.fuzzy_intersection(
gold_set, top3_set, fuzzy_mode, similarity_threshold
)
false_negatives_top3 = self.fuzzy_difference(
gold_set, top3_set, fuzzy_mode, similarity_threshold
)
recall_top3 = (
true_positives_top3 / (true_positives_top3 + false_negatives_top3)
@ -101,8 +271,12 @@ class Evaluate:
)
# Update counters for top-5
true_positives_top5 = len(gold_set.intersection(top5_set))
false_negatives_top5 = len(gold_set - top5_set)
true_positives_top5 = self.fuzzy_intersection(
gold_set, top5_set, fuzzy_mode, similarity_threshold
)
false_negatives_top5 = self.fuzzy_difference(
gold_set, top5_set, fuzzy_mode, similarity_threshold
)
recall_top5 = (
true_positives_top5 / (true_positives_top5 + false_negatives_top5)
@ -110,8 +284,12 @@ class Evaluate:
else 0.0
)
# Update counters for all
true_positives_all = len(gold_set.intersection(all_set))
false_negatives_all = len(gold_set - all_set)
true_positives_all = self.fuzzy_intersection(
gold_set, all_set, fuzzy_mode, similarity_threshold
)
false_negatives_all = self.fuzzy_difference(
gold_set, all_set, fuzzy_mode, similarity_threshold
)
recall_all = (
true_positives_all / (true_positives_all + false_negatives_all)

View File

@ -30,7 +30,7 @@ from kag.interface.solver.model.schema_utils import SchemaUtils
from kag.common.config import LogicFormConfiguration
from kag.common.tools.search_api.search_api_abc import SearchApiABC
from kag.common.tools.graph_api.graph_api_abc import GraphApiABC
from knext.schema.client import CHUNK_TYPE
from knext.schema.client import CHUNK_TYPE, TABLE_TYPE
logger = logging.getLogger()
chunk_cached_by_query_map = knext.common.cache.LinkCache(maxsize=100, ttl=300)
@ -83,13 +83,26 @@ class AtomicQueryChunkRetriever(RetrieverABC):
label=self.schema_helper.get_label_within_prefix(CHUNK_TYPE),
biz_id=doc_id,
)
if doc == {}:
doc = self.graph_api.get_entity_prop_by_id(
label=self.schema_helper.get_label_within_prefix(TABLE_TYPE),
biz_id=doc_id,
)
if doc == {}:
doc = self.graph_api.get_entity_prop_by_id(
label=self.schema_helper.get_label_within_prefix("Summary"),
biz_id=doc_id,
)
return ChunkData(
content=doc["content"],
title=doc["name"],
chunk_id=doc["id"],
score=atomic_query["score"],
)
if doc.get("content"):
return ChunkData(
content=doc["content"],
title=doc["name"],
chunk_id=doc["id"],
score=atomic_query["score"],
)
else:
return None
def parse_chosen_atom_infos(self, context: Context):
if not context:
@ -110,8 +123,13 @@ class AtomicQueryChunkRetriever(RetrieverABC):
with_json_parse=False,
)
rewritten_queries = rewritten_queries.append(query)
return rewritten_queries
rewritten_queries.append(query)
if rewritten_queries is not None:
return rewritten_queries
else:
rewritten_queries = []
rewritten_queries.append(query)
return rewritten_queries
async def recall_atomic_query(self, query: str, context: Context):
# rewrite query to expand diversity
@ -120,21 +138,30 @@ class AtomicQueryChunkRetriever(RetrieverABC):
rewritten_queries_vector_list = await self.vectorize_model.avectorize(
rewritten_queries
)
while rewritten_queries_vector_list is None:
rewritten_queries_vector_list = await self.vectorize_model.avectorize(
rewritten_queries
)
# recall atomic_query
tasks = []
for rewritten_queries_vector in rewritten_queries_vector_list:
task = asyncio.create_task(
asyncio.to_thread(
lambda: self.search_api.search_vector(
label=self.schema_helper.get_label_within_prefix("AtomicQuery"),
property_key="name",
query_vector=rewritten_queries_vector,
topk=self.top_k,
if rewritten_queries_vector_list is not None:
for rewritten_queries_vector in rewritten_queries_vector_list:
task = asyncio.create_task(
asyncio.to_thread(
lambda: self.search_api.search_vector(
label=self.schema_helper.get_label_within_prefix(
"AtomicQuery"
),
property_key="name",
query_vector=rewritten_queries_vector,
topk=self.top_k,
)
)
)
)
tasks.append(task)
tasks.append(task)
else:
tasks = []
top_k_atomic_queries = await asyncio.gather(*tasks)
top_k_atomic_queries_with_threshold = {}

View File

@ -137,6 +137,10 @@ class KAGRetrieverOutputMerger(RetrieverOutputMerger):
Returns:
RetrieverOutput: Final merged output containing unified chunk list.
"""
retrieved_chunks = kwargs.get("retrieved_chunks", None)
chunk_lists = [x.chunks for x in retrieve_outputs]
merged = self.chunk_merge(chunk_lists, self.rrf_normalize)
if retrieved_chunks is not None:
chunk_texts = [x.content for x in merged]
retrieved_chunks.extend(chunk_texts)
return RetrieverOutput(chunks=merged)

View File

@ -37,9 +37,20 @@ async def buildKB(dir_path):
)
def buildKB_debug(dir_path):
from kag.common.conf import KAG_CONFIG
runner = BuilderChainRunner.from_config(
KAG_CONFIG.all_config["kag_builder_pipeline"]
)
runner.invoke(dir_path)
if __name__ == "__main__":
dir_path = os.path.dirname(os.path.abspath(__file__))
import_modules_from_path(dir_path)
module_path = os.path.dirname(dir_path)
import_modules_from_path(module_path)
data_dir_path = os.path.join(dir_path, "data")
asyncio.run(buildKB(data_dir_path))
# buildKB_debug(data_dir_path)

View File

@ -36,18 +36,29 @@ class EvaForAFAC2024(EvalQa):
task_name="afac2024", solver_pipeline_name=solver_pipeline_name
)
async def qa(self, query, gold):
async def qa(self, query, gold, **kwargs):
reporter: TraceLogReporter = TraceLogReporter()
retrieved_chunks = []
pipeline = SolverPipelineABC.from_config(
KAG_CONFIG.all_config[self.solver_pipeline_name]
)
answer = await pipeline.ainvoke(query, reporter=reporter, gold=gold)
answer = await pipeline.ainvoke(
query,
reporter=reporter,
gold=gold,
retrieved_chunks=retrieved_chunks,
**kwargs,
)
logger.info(f"\n\nso the answer for '{query}' is: {answer}\n\n")
info, status = reporter.generate_report_data()
return answer, {"info": info.to_dict(), "status": status}
return answer, {
"info": info.to_dict(),
"status": status,
"retrieved_chunks": retrieved_chunks,
}
async def async_process_sample(self, data):
sample_idx, sample, ckpt = data
@ -58,10 +69,13 @@ class EvaForAFAC2024(EvalQa):
print(f"found existing answer to question: {question}")
prediction, trace_log = ckpt.read_from_ckpt(question)
else:
prediction, trace_log = await self.qa(query=question, gold=gold)
prediction, trace_log = await self.qa(
query=question, gold=gold, ckpt=ckpt
)
if ckpt:
ckpt.write_to_ckpt(question, (prediction, trace_log))
metrics = self.do_metrics_eval([question], [prediction], [gold])
metrics["recall"] = self.do_recall_eval(sample, [prediction], trace_log)
return sample_idx, prediction, metrics, trace_log
except Exception as e:
import traceback
@ -81,6 +95,19 @@ class EvaForAFAC2024(EvalQa):
eva_obj = Evaluate()
return eva_obj.getBenchMark(questionList, predictions, golds)
def do_recall_eval(self, sample, references, trace_log):
eva_obj = Evaluate()
predictions = trace_log.get("retrieved_chunks", None)
goldlist = []
for s in sample["supporting_facts"]:
goldlist.extend(s[1:])
return eva_obj.recall_top(
predictionlist=predictions,
goldlist=goldlist,
is_chunk_data=False,
fuzzy_mode="chinese_only",
)
if __name__ == "__main__":
import_modules_from_path("./src")

View File

@ -101,9 +101,22 @@ class EvalQa:
total_metrics = {
"processNum": len(metrics_list),
}
recall_metrics = {}
hit3 = 0.0
hit5 = 0.0
hitall = 0.0
for metric in metrics_list:
recall_data = metric["recall"]
hit3 += recall_data["recall_top3"]
hit5 += recall_data["recall_top5"]
hitall += recall_data["recall_all"]
recall_metrics["hit3"] = hit3 / len(metrics_list)
recall_metrics["hit5"] = hit5 / len(metrics_list)
recall_metrics["hitall"] = hitall / len(metrics_list)
if len(metrics_list) == 0:
return total_metrics
res_metrics = {}
res_metrics["recall"] = recall_metrics
for metric in metrics_list:
for k, v in metric.items():
if not isinstance(v, int) and not isinstance(v, float):

View File

@ -23,6 +23,7 @@ cache = knext.common.cache.SchemaCache()
CHUNK_TYPE = "Chunk"
TABLE_TYPE = "Table"
TITLE_TYPE = "Title"
OTHER_TYPE = "Others"
TEXT_TYPE = "Text"