FlagEmbedding/research/C_MTEB/summarize_results.py

import argparse
import json
import os
from collections import defaultdict

from C_MTEB import *
import mteb
from mteb import MTEB

CMTEB_tasks = [
    'TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'OnlineShopping', 'Waimai',
    'CLSClusteringS2S.v2', 'CLSClusteringP2P.v2', 'ThuNewsClusteringS2S.v2', 'ThuNewsClusteringP2P.v2',
    'Ocnli', 'Cmnli',
    'T2Reranking', 'MMarcoReranking', 'CMedQAv1-reranking', 'CMedQAv2-reranking',
    'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
    'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC'
]
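
# read_results() below keeps only the tasks listed above; it expects one JSON file
# per task under args.results_dir (e.g. <results_dir>/TNews.json), as typically
# written by an mteb evaluation run.
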
def read_results(task_types, args):
    tasks_results = {}
    for t_type in task_types:
        tasks_results[t_type] = {}
        for t in mteb.get_tasks(task_types=[t_type]):
            task_name = t.metadata.name
            if task_name not in CMTEB_tasks:
                continue

            metric = t.metadata.main_score

            tasks_results[t_type][task_name] = defaultdict(None)
            result_file = os.path.join(args.results_dir, task_name + '.json')
            if os.path.exists(result_file):
                with open(result_file) as fin:
                    data = json.load(fin)
                # prefer the test split, then dev, then validation
                for s in ['test', 'dev', 'validation']:
                    if s in data['scores']:
                        split = s
                        break
                temp_data = data['scores'][split][0]
                # report the task's main score as a percentage with two decimals
                tasks_results[t_type][task_name] = round(temp_data[metric] * 100, 2)
    return tasks_results
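
# Illustrative shape of the value returned by read_results (scores are percentages
# rounded to two decimals; the numbers here are placeholders, not real results):
#   {
#       'Classification': {'TNews': 52.31, 'IFlyTek': 45.10, ...},
#       'STS': {'ATEC': 49.75, ...},
#       ...
#   }
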
def output_markdown(tasks_results, model, save_file):
    task_type_res = {}
    with open(save_file, 'w') as f:
        for t_type, type_results in tasks_results.items():
            has_CQADupstack = False
            task_cnt = 0

            task_type_res[t_type] = defaultdict()
            f.write(f'Task Type: {t_type} \n')
            first_line = "| Model |"
            second_line = "|:-------------------------------|"
            for task_name in type_results.keys():
                # CQADupstack sub-tasks are merged into a single averaged column
                if "CQADupstack" in task_name:
                    has_CQADupstack = True
                    continue
                first_line += f" {task_name} |"
                second_line += ":--------:|"
                task_cnt += 1
            if has_CQADupstack:
                first_line += f" CQADupstack |"
                second_line += ":--------:|"
                task_cnt += 1
            f.write(first_line + ' Avg | \n')
            f.write(second_line + ':--------:| \n')

            write_line = f"| {model} |"
            all_res = []
            cqa_res = []
            for task_name, results in type_results.items():
                if "CQADupstack" in task_name:
                    # results holds the task score directly (see read_results)
                    if isinstance(results, (int, float)):
                        cqa_res.append(results)
                    continue
                write_line += f" {results} |"
                all_res.append(results)
            if len(cqa_res) > 0:
                write_line += f" {round(sum(cqa_res) / len(cqa_res), 2)} |"
                all_res.append(round(sum(cqa_res) / len(cqa_res), 2))
            # only report an average when every task in this type has a score
            if len(all_res) == task_cnt:
                write_line += f" {round(sum(all_res) / len(all_res), 2)} |"
                task_type_res[t_type][model] = all_res
            else:
                write_line += f" |"
            f.write(write_line + ' \n\n')

        f.write(f'Overall \n')
        first_line = "| Model |"
        second_line = "|:-------------------------------|"
        for t_type in task_type_res.keys():
            first_line += f" {t_type} |"
            second_line += ":--------:|"
        f.write(first_line + ' Avg | \n')
        f.write(second_line + ':--------:| \n')
        write_line = f"| {model} |"
        all_res = []
        for type_name, results in task_type_res.items():
            if model in results:
                write_line += f" {round(sum(results[model]) / len(results[model]), 2)} |"
                all_res.extend(results[model])
            else:
                write_line += f" |"
        if len(all_res) > 0:
            write_line += f" {round(sum(all_res) / len(all_res), 2)} |"
        f.write(write_line + ' \n')

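# Rough sketch of the emitted markdown (values are placeholders, not real scores):
#   Task Type: STS
#   | Model | ATEC | BQ | ... | Avg |
#   |:------|:----:|:--:| ... |:---:|
#   | my-model | 49.75 | 61.20 | ... | 55.48 |
# followed by an "Overall" table with one averaged column per task type.
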
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--results_dir', default="./zh_results", type=str)
    parser.add_argument('--lang', default="zh", type=str)
    parser.add_argument('--model', default="model", type=str)
    return parser.parse_args()

if __name__ == '__main__':
    args = get_args()
    if args.lang == 'zh':
        task_types = ["Retrieval", "STS", "PairClassification", "Classification", "Reranking", "Clustering"]
        args.lang = ['zho']
    elif args.lang == 'en':
        task_types = ["Retrieval", "Clustering", "PairClassification", "Reranking", "STS", "Summarization",
                      "Classification"]
        args.lang = ['eng']
    else:
        raise NotImplementedError(f"args.lang must be zh or en, but got {args.lang}")

    task_results = read_results(task_types, args=args)
    output_markdown(task_results, args.model,
                    save_file=os.path.join(args.results_dir, f'{args.lang[0]}_results.md'))
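
# Example invocation (the model name is a placeholder):
#   python summarize_results.py --results_dir ./zh_results --lang zh --model bge-large-zh
# This writes a markdown summary table to ./zh_results/zho_results.md.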