2023-08-02 17:40:00 +08:00
|
|
|
import argparse
|
2023-08-03 11:16:51 +08:00
|
|
|
from collections import defaultdict
|
2023-08-04 16:22:35 +08:00
|
|
|
import os
|
|
|
|
import json
|
2023-08-02 17:40:00 +08:00
|
|
|
|
2023-08-03 11:16:51 +08:00
|
|
|
from mteb import MTEB
|
2023-08-04 16:22:35 +08:00
|
|
|
from C_MTEB import *
|
2023-08-02 17:40:00 +08:00
|
|
|
|
|
|
|
|
|
|
|
def read_results(task_types, except_tasks, args):
|
|
|
|
tasks_results = {}
|
|
|
|
model_dirs = {}
|
|
|
|
for t_type in task_types:
|
|
|
|
tasks_results[t_type] = {}
|
2023-08-09 12:21:54 +08:00
|
|
|
for t in MTEB(task_types=[t_type], task_langs=args.lang).tasks:
|
2023-08-02 17:40:00 +08:00
|
|
|
task_name = t.description["name"]
|
|
|
|
if task_name in except_tasks: continue
|
|
|
|
|
|
|
|
metric = t.description["main_score"]
|
|
|
|
tasks_results[t_type][task_name] = defaultdict(None)
|
|
|
|
|
|
|
|
for model_name in os.listdir(args.results_dir):
|
|
|
|
model_dir = os.path.join(args.results_dir, model_name)
|
|
|
|
if not os.path.isdir(model_dir): continue
|
|
|
|
model_dirs[model_name] = model_dir
|
|
|
|
if os.path.exists(os.path.join(model_dir, task_name + '.json')):
|
|
|
|
data = json.load(open(os.path.join(model_dir, task_name + '.json')))
|
2023-08-09 12:21:54 +08:00
|
|
|
for s in ['test', 'dev', 'validation']:
|
2023-08-02 17:40:00 +08:00
|
|
|
if s in data:
|
|
|
|
split = s
|
|
|
|
break
|
|
|
|
|
|
|
|
if 'en-en' in data[split]:
|
|
|
|
temp_data = data[split]['en-en']
|
|
|
|
elif 'en' in data[split]:
|
|
|
|
temp_data = data[split]['en']
|
|
|
|
else:
|
|
|
|
temp_data = data[split]
|
|
|
|
|
|
|
|
if metric == 'ap':
|
|
|
|
tasks_results[t_type][task_name][model_name] = round(temp_data['cos_sim']['ap'] * 100, 2)
|
|
|
|
elif metric == 'cosine_spearman':
|
|
|
|
tasks_results[t_type][task_name][model_name] = round(temp_data['cos_sim']['spearman'] * 100, 2)
|
|
|
|
else:
|
|
|
|
tasks_results[t_type][task_name][model_name] = round(temp_data[metric] * 100, 2)
|
|
|
|
|
|
|
|
return tasks_results, model_dirs
|
|
|
|
|
|
|
|
|
|
|
|
def output_markdown(tasks_results, model_names, save_file):
|
|
|
|
task_type_res = {}
|
|
|
|
with open(save_file, 'w') as f:
|
|
|
|
for t_type, type_results in tasks_results.items():
|
2023-08-04 16:22:35 +08:00
|
|
|
has_CQADupstack = False
|
|
|
|
task_cnt = 0
|
2023-08-02 17:40:00 +08:00
|
|
|
task_type_res[t_type] = defaultdict()
|
|
|
|
f.write(f'Task Type: {t_type} \n')
|
|
|
|
first_line = "| Model |"
|
|
|
|
second_line = "|:-------------------------------|"
|
|
|
|
for task_name in type_results.keys():
|
2023-08-04 16:22:35 +08:00
|
|
|
if "CQADupstack" in task_name:
|
|
|
|
has_CQADupstack = True
|
|
|
|
continue
|
2023-08-02 17:40:00 +08:00
|
|
|
first_line += f" {task_name} |"
|
|
|
|
second_line += ":--------:|"
|
2023-08-04 16:22:35 +08:00
|
|
|
task_cnt += 1
|
|
|
|
if has_CQADupstack:
|
|
|
|
first_line += f" CQADupstack |"
|
|
|
|
second_line += ":--------:|"
|
|
|
|
task_cnt += 1
|
|
|
|
f.write(first_line+' Avg | \n')
|
|
|
|
f.write(second_line+':--------:| \n')
|
2023-08-02 17:40:00 +08:00
|
|
|
|
|
|
|
for model in model_names:
|
|
|
|
write_line = f"| {model} |"
|
|
|
|
all_res = []
|
2023-08-04 16:22:35 +08:00
|
|
|
cqa_res = []
|
2023-08-02 17:40:00 +08:00
|
|
|
for task_name, results in type_results.items():
|
2023-08-04 16:22:35 +08:00
|
|
|
if "CQADupstack" in task_name:
|
|
|
|
if model in results:
|
|
|
|
cqa_res.append(results[model])
|
|
|
|
continue
|
|
|
|
|
2023-08-02 17:40:00 +08:00
|
|
|
if model in results:
|
|
|
|
write_line += f" {results[model]} |"
|
|
|
|
all_res.append(results[model])
|
|
|
|
else:
|
|
|
|
write_line += f" |"
|
|
|
|
|
2023-08-04 16:22:35 +08:00
|
|
|
if len(cqa_res) > 0:
|
|
|
|
write_line += f" {round(sum(cqa_res)/len(cqa_res), 2)} |"
|
|
|
|
all_res.append(round(sum(cqa_res)/len(cqa_res), 2))
|
|
|
|
|
|
|
|
# if len(all_res) == len(type_results.keys()):
|
|
|
|
if len(all_res) == task_cnt:
|
|
|
|
write_line += f" {round(sum(all_res)/len(all_res), 2)} |"
|
2023-08-02 17:40:00 +08:00
|
|
|
task_type_res[t_type][model] = all_res
|
|
|
|
else:
|
|
|
|
write_line += f" |"
|
2023-08-04 16:22:35 +08:00
|
|
|
f.write(write_line+' \n')
|
|
|
|
|
2023-08-02 17:40:00 +08:00
|
|
|
|
|
|
|
f.write(f'Overall \n')
|
|
|
|
first_line = "| Model |"
|
|
|
|
second_line = "|:-------------------------------|"
|
|
|
|
for t_type in task_type_res.keys():
|
|
|
|
first_line += f" {t_type} |"
|
|
|
|
second_line += ":--------:|"
|
|
|
|
f.write(first_line + ' Avg | \n')
|
|
|
|
f.write(second_line + ':--------:| \n')
|
|
|
|
|
|
|
|
for model in model_names:
|
|
|
|
write_line = f"| {model} |"
|
|
|
|
all_res = []
|
|
|
|
for type_name, results in task_type_res.items():
|
|
|
|
if model in results:
|
2023-08-04 16:22:35 +08:00
|
|
|
write_line += f" {round(sum(results[model])/len(results[model]), 2)} |"
|
2023-08-02 17:40:00 +08:00
|
|
|
all_res.extend(results[model])
|
|
|
|
else:
|
|
|
|
write_line += f" |"
|
|
|
|
|
|
|
|
if len(all_res) > 0:
|
|
|
|
write_line += f" {round(sum(all_res) / len(all_res), 2)} |"
|
|
|
|
|
|
|
|
f.write(write_line + ' \n')
|
|
|
|
|
|
|
|
|
2023-08-04 16:22:35 +08:00
|
|
|
|
|
|
|
|
2023-08-02 17:40:00 +08:00
|
|
|
def get_args():
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
parser.add_argument('--results_dir', default="./zh_results", type=str)
|
|
|
|
parser.add_argument('--lang', default="zh", type=str)
|
|
|
|
return parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
args = get_args()
|
|
|
|
|
|
|
|
if args.lang == 'zh':
|
|
|
|
task_types = ["Retrieval", "STS", "PairClassification", "Classification", "Reranking", "Clustering"]
|
2023-08-09 12:21:54 +08:00
|
|
|
except_tasks = []
|
|
|
|
args.lang = ['zh', 'zh-CN']
|
2023-08-02 17:40:00 +08:00
|
|
|
elif args.lang == 'en':
|
2023-08-04 16:22:35 +08:00
|
|
|
task_types = ["Retrieval", "Clustering", "PairClassification", "Reranking", "STS", "Summarization", "Classification" ]
|
|
|
|
except_tasks = ['MSMARCOv2']
|
2023-08-09 12:21:54 +08:00
|
|
|
args.lang = ['en']
|
2023-08-02 17:40:00 +08:00
|
|
|
else:
|
|
|
|
raise NotImplementedError(f"args.lang must be zh or en, but{args.lang}")
|
|
|
|
|
|
|
|
task_results, model_dirs = read_results(task_types, except_tasks, args=args)
|
2023-08-04 16:22:35 +08:00
|
|
|
|
2023-08-09 12:21:54 +08:00
|
|
|
output_markdown(task_results, model_dirs.keys(), save_file=os.path.join(args.results_dir, f'{args.lang[0]}_results.md'))
|
2023-08-04 16:22:35 +08:00
|
|
|
|
|
|
|
|