mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2025-12-30 16:52:05 +00:00
eval beir
This commit is contained in:
parent 6d71e83cf2
commit 00a42ccd4f
@@ -0,0 +1,96 @@
from transformers import HfArgumentParser

from FlagEmbedding import FlagAutoModel, FlagAutoReranker
from FlagEmbedding.abc.evaluation import AbsModelArgs, AbsEmbedder, AbsReranker

from utils.arguments import BEIREvalArgs
from utils.data_loader import BEIRDataLoader
from utils.evaluator import BEIREvaluator


def get_models(model_args: AbsModelArgs):
    retriever = FlagAutoModel.from_finetuned(
        model_name_or_path=model_args.embedder_name_or_path,
        normalize_embeddings=model_args.normalize_embeddings,
        use_fp16=model_args.use_fp16,
        query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
        query_instruction_format=model_args.query_instruction_format_for_retrieval,
        devices=model_args.devices,
        examples_for_task=model_args.examples_for_task,
        examples_instruction_format=model_args.examples_instruction_format,
        trust_remote_code=model_args.trust_remote_code,
        cache_dir=model_args.cache_dir
    )
    reranker = None
    if model_args.reranker_name_or_path is not None:
        reranker = FlagAutoReranker.from_finetuned(
            model_name_or_path=model_args.reranker_name_or_path,
            peft_path=model_args.reranker_peft_path,
            use_fp16=model_args.use_fp16,
            use_bf16=model_args.use_bf16,
            query_instruction_for_rerank=model_args.query_instruction_for_rerank,
            query_instruction_format=model_args.query_instruction_format_for_rerank,
            passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
            passage_instruction_format=model_args.passage_instruction_format_for_rerank,
            cache_dir=model_args.cache_dir,
            trust_remote_code=model_args.trust_remote_code,
            devices=model_args.devices,
            normalize=model_args.normalize,
            prompt=model_args.prompt,
            cutoff_layers=model_args.cutoff_layers,
            compress_layers=model_args.compress_layers,
            compress_ratio=model_args.compress_ratio,
        )
    return retriever, reranker


def main():
    parser = HfArgumentParser([AbsModelArgs, BEIREvalArgs])
    model_args, eval_args = parser.parse_args_into_dataclasses()
    model_args: AbsModelArgs
    eval_args: BEIREvalArgs

    retriever, reranker = get_models(model_args)
    retriever = AbsEmbedder(
        retriever,
        search_top_k=eval_args.search_top_k,
    )

    if reranker is not None:
        reranker = AbsReranker(
            reranker,
            rerank_top_k=eval_args.rerank_top_k,
        )
    else:
        reranker = None

    for dataset_name in eval_args.dataset_names:
        data_loader = BEIRDataLoader(
            dataset_dir=eval_args.dataset_dir,
            cache_dir=eval_args.cache_path,
            dataset_name=dataset_name
        )

        evaluation = BEIREvaluator(
            data_loader=data_loader,
            overwrite=eval_args.overwrite,
        )

        evaluation(
            splits=eval_args.splits.split(),
            search_results_save_dir=eval_args.output_dir,
            retriever=retriever,
            reranker=reranker,
            corpus_embd_save_dir=eval_args.corpus_embd_save_dir,
            retriever_batch_size=model_args.retriever_batch_size,
            reranker_batch_size=model_args.reranker_batch_size,
            retriever_query_max_length=model_args.retriever_query_max_length,
            retriever_passage_max_length=model_args.retriever_passage_max_length,
            reranker_max_length=model_args.reranker_max_length,
        )


if __name__ == "__main__":
    main()
12 FlagEmbedding/evaluation/beir/run.sh Normal file
@@ -0,0 +1,12 @@
python __main__.py \
    --dataset_dir /share/chaofan/code/FlagEmbedding_update/data/BEIR \
    --embedder_name_or_path BAAI/bge-large-en-v1.5 \
    --reranker_name_or_path BAAI/bge-reranker-large \
    --query_instruction_for_retrieval "Represent this sentence for searching relevant passages: " \
    --use_fp16 True \
    --devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7 \
    --cache_dir /share/shared_models \
    --corpus_embd_save_dir /share/chaofan/code/FlagEmbedding_update/data/BEIR_passage_embds \
    --reranker_max_length 512 \
    --dataset_names nfcorpus fiqa cqadupstack
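For reference, a minimal sketch (not part of the commit; the dataclass and its fields are illustrative stand-ins for the real AbsModelArgs/BEIREvalArgs) of how HfArgumentParser maps run.sh-style flags onto dataclass fields:

# Hypothetical demo of the flag-to-dataclass mapping used by __main__.py.
from dataclasses import dataclass, field
from typing import List

from transformers import HfArgumentParser


@dataclass
class DemoEvalArgs:  # illustrative stand-in, not the real argument class
    dataset_dir: str = field(default="./data/BEIR")
    dataset_names: List[str] = field(default_factory=lambda: ["nfcorpus"])
    use_fp16: bool = field(default=False)


parser = HfArgumentParser([DemoEvalArgs])
# Equivalent to: python demo.py --dataset_dir ./BEIR --dataset_names nfcorpus fiqa --use_fp16 True
(eval_args,) = parser.parse_args_into_dataclasses(
    ["--dataset_dir", "./BEIR", "--dataset_names", "nfcorpus", "fiqa", "--use_fp16", "True"]
)
print(eval_args)  # DemoEvalArgs(dataset_dir='./BEIR', dataset_names=['nfcorpus', 'fiqa'], use_fp16=True)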
@@ -1,4 +1,5 @@
 from dataclasses import dataclass, field
+from typing import List
 
 from FlagEmbedding.abc.evaluation.arguments import AbsEvalArgs
 
@@ -15,7 +15,7 @@ from FlagEmbedding.abc.evaluation.data_loader import AbsDataLoader
 
 logger = logging.getLogger(__name__)
 
-class MSMARCODataLoader(AbsDataLoader):
+class BEIRDataLoader(AbsDataLoader):
     def __init__(
         self,
         dataset_dir: str,  # the dataset dir to load from
@@ -32,16 +32,15 @@ class MSMARCODataLoader(AbsDataLoader):
         self.split = 'test'
         if dataset_name == 'msmarco': self.split = 'dev'
 
+        os.makedirs(self.dataset_dir, exist_ok=True)
 
         if dataset_name != 'cqadupstack':
-            os.makedirs(self.dataset_dir, exist_ok=True)
             qrels_path = os.path.join(self.dataset_dir, 'qrels-{split}.json'.format(split=self.split))
             corpus_path = os.path.join(self.dataset_dir, 'corpus.json')
             queries_path = os.path.join(self.dataset_dir, 'queries-{split}.json'.format(split=self.split))
             queries, corpus, rels = {}, {}, {}
             if not os.path.exists(corpus_path):
                 dataset = datasets.load_dataset(
-                    'BeIR/{d}'.format(d=data_name),
+                    'BeIR/{d}'.format(d=dataset_name),
                     'corpus',
                     trust_remote_code=True,
                     cache_dir=os.getenv('HF_HUB_CACHE', cache_dir),
@@ -55,7 +54,7 @@
 
             if not os.path.exists(queries_path):
                 dataset = datasets.load_dataset(
-                    'BeIR/{d}'.format(d=data_name),
+                    'BeIR/{d}'.format(d=dataset_name),
                     'queries',
                     trust_remote_code=True,
                     cache_dir=os.getenv('HF_HUB_CACHE', cache_dir),
@@ -69,7 +68,7 @@
 
             if not os.path.exists(qrels_path):
                 dataset = datasets.load_dataset(
-                    'BeIR/{d}-qrels'.format(d=data_name),
+                    'BeIR/{d}-qrels'.format(d=dataset_name),
                     split=self.split,
                     trust_remote_code=True,
                     cache_dir=os.getenv('HF_HUB_CACHE', cache_dir),
@@ -93,10 +92,8 @@
                 if not os.path.exists(corpus_path) or not os.path.exists(queries_path) or not os.path.exists(qrels_path):
                     full_path = os.path.join(data_path, sub_dataset_name)
                     corpus, queries, qrels = GenericDataLoader(data_folder=full_path).load(split="test")
-                    for k in corpus_data.keys():
-                        corpus[k] = {
-                            (corpus[k]['title'] + ' ' + corpus[k]['text']).strip()
-                        }
+                    for k in corpus.keys():
+                        corpus[k] = (corpus[k]['title'] + ' ' + corpus[k]['text']).strip()
 
                     with open(corpus_path, 'w') as f:
                         json.dump(corpus, f)
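To make the change above concrete, a standalone toy example of the flattening the new code performs (each BEIR document's title and text are merged into one string before being dumped to JSON):

corpus = {
    "d1": {"title": "BM25", "text": "A classic ranking function."},
    "d2": {"title": "", "text": "A document without a title."},
}
for k in corpus.keys():
    corpus[k] = (corpus[k]["title"] + " " + corpus[k]["text"]).strip()

assert corpus["d2"] == "A document without a title."  # strip() drops the leading space
print(corpus)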
@@ -105,25 +102,40 @@
                     json.dump(queries, f)
 
                 with open(qrels_path, 'w') as f:
-                    json.dump(rels, f)
+                    json.dump(qrels, f)
 
 
-    def load_qrels(self):
-        if self.sub_dataset_name is None:
+    def load_qrels(self, sub_dataset_name: str = None, split: str = None):
+        if sub_dataset_name is None and split is not None:
+            sub_dataset_name = split.split('-')
+            if len(sub_dataset_name) == 1 or len(sub_dataset_name[0].strip()) == 0:
+                sub_dataset_name = None
+            else:
+                sub_dataset_name = sub_dataset_name[0].strip()
+        if sub_dataset_name is None:
             qrels_path = os.path.join(self.dataset_dir, 'qrels-{split}.json'.format(split=self.split))
             rels = json.load(open(qrels_path))
             return datasets.DatasetDict(rels)
         else:
-            rels_list = []
-            for sub_dataset_name in self.sub_dataset_names:
-                qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
-                rels = json.load(open(qrels_path))
-                rels_list.append(datasets.DatasetDict(rels))
-            return rels_list
+            # rels_list = []
+            # for sub_dataset_name in self.sub_dataset_names:
+            #     qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
+            #     rels = json.load(open(qrels_path))
+            #     rels_list.append(datasets.DatasetDict(rels))
+            # return rels_list
+            qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
+            rels = json.load(open(qrels_path))
+            return datasets.DatasetDict(rels)
 
 
-    def load_corpus(self):
-        if self.sub_dataset_name is None:
+    def load_corpus(self, sub_dataset_name: str = None, split: str = None):
+        if sub_dataset_name is None and split is not None:
+            sub_dataset_name = split.split('-')
+            if len(sub_dataset_name) == 1 or len(sub_dataset_name[0].strip()) == 0:
+                sub_dataset_name = None
+            else:
+                sub_dataset_name = sub_dataset_name[0].strip()
+        if sub_dataset_name is None:
             corpus_path = os.path.join(self.dataset_dir, 'corpus.json')
             corpus = json.load(open(corpus_path))
             for k in corpus.keys():
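The split-string convention introduced above is easy to misread inline; here is a standalone mirror of the same parsing (the function name is ours, for illustration only). A split like "android-test", as used for cqadupstack sub-datasets, yields sub_dataset_name "android", while a plain "test" yields None:

def parse_sub_dataset_name(split: str):
    parts = split.split('-')
    if len(parts) == 1 or len(parts[0].strip()) == 0:
        return None  # plain split such as "test" or "dev": no sub-dataset
    return parts[0].strip()

assert parse_sub_dataset_name("test") is None
assert parse_sub_dataset_name("android-test") == "android"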
@@ -132,15 +144,28 @@ class MSMARCODataLoader(AbsDataLoader):
                 }
             return datasets.DatasetDict(corpus)
         else:
-            corpus_list = []
-            for sub_dataset_name in self.sub_dataset_names:
-                corpus_path = os.path.join(self.dataset_dir, sub_dataset_name, 'corpus.json')
-                corpus = json.load(open(corpus_path))
-                corpus_list.append(datasets.DatasetDict(corpus))
-            return corpus_list
+            corpus_path = os.path.join(self.dataset_dir, sub_dataset_name, 'corpus.json')
+            corpus = json.load(open(corpus_path))
+            for k in corpus.keys():
+                corpus[k] = {
+                    'text': corpus[k]
+                }
+            return datasets.DatasetDict(corpus)
+            # corpus_list = []
+            # for sub_dataset_name in self.sub_dataset_names:
+            #     corpus_path = os.path.join(self.dataset_dir, sub_dataset_name, 'corpus.json')
+            #     corpus = json.load(open(corpus_path))
+            #     corpus_list.append(datasets.DatasetDict(corpus))
+            # return corpus_list
 
-    def load_queries(self, split='dev'):
-        if self.sub_dataset_name is None:
+    def load_queries(self, sub_dataset_name: str = None, split: str = None):
+        if sub_dataset_name is None and split is not None:
+            sub_dataset_name = split.split('-')
+            if len(sub_dataset_name) == 1 or len(sub_dataset_name[0].strip()) == 0:
+                sub_dataset_name = None
+            else:
+                sub_dataset_name = sub_dataset_name[0].strip()
+        if sub_dataset_name is None:
             queries_path = os.path.join(self.dataset_dir, 'queries-{split}.json'.format(split=self.split))
             qrels_path = os.path.join(self.dataset_dir, 'qrels-{split}.json'.format(split=self.split))
             queries = json.load(open(queries_path))
@@ -151,15 +176,25 @@
                     new_queries[k] = queries[k]
             return datasets.DatasetDict(new_queries)
         else:
-            queries_list = []
-            for sub_dataset_name in self.sub_dataset_names:
-                queries_path = os.path.join(self.dataset_dir, sub_dataset_name, 'queries-{split}.json'.format(split=self.split))
-                qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
-                queries = json.load(open(queries_path))
-                rels = json.load(open(qrels_path))
-                new_queries = {}
-                for k in queries.keys():
-                    if k in rels.keys():
-                        new_queries[k] = queries[k]
-                queries_list.append(datasets.DatasetDict(new_queries))
-            return queries_list
+            queries_path = os.path.join(self.dataset_dir, sub_dataset_name, 'queries-{split}.json'.format(split=self.split))
+            qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
+            queries = json.load(open(queries_path))
+            print(qrels_path)
+            rels = json.load(open(qrels_path))
+            new_queries = {}
+            for k in queries.keys():
+                if k in rels.keys():
+                    new_queries[k] = queries[k]
+            return datasets.DatasetDict(new_queries)
+            # queries_list = []
+            # for sub_dataset_name in self.sub_dataset_names:
+            #     queries_path = os.path.join(self.dataset_dir, sub_dataset_name, 'queries-{split}.json'.format(split=self.split))
+            #     qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
+            #     queries = json.load(open(queries_path))
+            #     rels = json.load(open(qrels_path))
+            #     new_queries = {}
+            #     for k in queries.keys():
+            #         if k in rels.keys():
+            #             new_queries[k] = queries[k]
+            #     queries_list.append(datasets.DatasetDict(new_queries))
+            # return queries_list
@@ -6,6 +6,8 @@ import pandas as pd
 from typing import Dict, Optional, List, Union
 
 from FlagEmbedding.abc.evaluation.evaluator import AbsEvaluator
+from FlagEmbedding.abc.evaluation.data_loader import AbsDataLoader
+from FlagEmbedding.abc.evaluation.searcher import AbsEmbedder, AbsReranker
 
 class BEIREvaluator(AbsEvaluator):
     def __init__(
@@ -33,50 +35,55 @@
     ):
         dataset_name = self.data_loader.dataset_name
+        sub_dataset_names = self.data_loader.sub_dataset_names
         split = self.data_loader.split
         if isinstance(splits, str):
             splits = [splits]
         # Retrieval Stage
         no_reranker_search_results_save_dir = os.path.join(
-            search_results_save_dir, str(retriever), "NoReranker"
+            search_results_save_dir, str(retriever), "NoReranker", dataset_name
         )
         os.makedirs(no_reranker_search_results_save_dir, exist_ok=True)
+        if corpus_embd_save_dir is not None: corpus_embd_save_dir = os.path.join(corpus_embd_save_dir, dataset_name)
 
         flag = False
+        if sub_dataset_names is None:
             split_no_reranker_search_results_save_path = os.path.join(
-                no_reranker_search_results_save_dir, f"{dataset_name}.json"
+                no_reranker_search_results_save_dir, f"{split}.json"
             )
             if not os.path.exists(split_no_reranker_search_results_save_path) or self.overwrite:
                 flag = True
-                break
+        else:
+            for sub_dataset_name in sub_dataset_names:
+                split_no_reranker_search_results_save_path = os.path.join(
+                    no_reranker_search_results_save_dir, f"{sub_dataset_name}-{split}.json"
+                )
+                if not os.path.exists(split_no_reranker_search_results_save_path) or self.overwrite:
+                    flag = True
+                    break
 
         no_reranker_search_results_dict = {}
         if flag:
-            corpus = self.data_loader.load_corpus()
-
-            queries_dict = {
-                split: self.data_loader.load_queries(split=split)
-                for split in splits
-            }
-
-            all_queries = {}
-            for _, split_queries in queries_dict.items():
-                all_queries.update(split_queries)
-
-            all_no_reranker_search_results = retriever(
-                corpus=corpus,
-                queries=all_queries,
-                corpus_embd_save_dir=corpus_embd_save_dir,
-                batch_size=retriever_batch_size,
-                query_max_length=retriever_query_max_length,
-                passage_max_length=retriever_passage_max_length,
-                **kwargs,
-            )
+            if sub_dataset_names is None:
+                corpus = self.data_loader.load_corpus(sub_dataset_name=sub_dataset_names)
+
+                queries_dict = {
+                    split: self.data_loader.load_queries(sub_dataset_name=sub_dataset_names)
+                }
+
+                all_queries = {}
+                for _, split_queries in queries_dict.items():
+                    all_queries.update(split_queries)
+
+                all_no_reranker_search_results = retriever(
+                    corpus=corpus,
+                    queries=all_queries,
+                    corpus_embd_save_dir=corpus_embd_save_dir,
+                    batch_size=retriever_batch_size,
+                    query_max_length=retriever_query_max_length,
+                    passage_max_length=retriever_passage_max_length,
+                    **kwargs,
+                )
 
                 for split in splits:
                     split_queries = queries_dict[split]
                     no_reranker_search_results_dict[split] = {
                         qid: all_no_reranker_search_results[qid] for qid in split_queries
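As a quick reference, a small sketch (the helper name is ours, for illustration) of the on-disk layout the retrieval stage above writes, with and without a sub-dataset:

import os

def no_reranker_save_path(save_dir, retriever, dataset_name, split, sub_dataset_name=None):
    base = os.path.join(save_dir, str(retriever), "NoReranker", dataset_name)
    name = f"{split}.json" if sub_dataset_name is None else f"{sub_dataset_name}-{split}.json"
    return os.path.join(base, name)

# e.g. results/bge-large-en-v1.5/NoReranker/cqadupstack/android-test.json
print(no_reranker_save_path("results", "bge-large-en-v1.5", "cqadupstack", "test", "android"))
# e.g. results/bge-large-en-v1.5/NoReranker/nfcorpus/test.json
print(no_reranker_save_path("results", "bge-large-en-v1.5", "nfcorpus", "test"))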
@@ -93,8 +100,47 @@
                         split=split,
                         dataset_dir=self.dataset_dir,
                     )
+            else:
+                for sub_dataset_name in sub_dataset_names:
+                    corpus = self.data_loader.load_corpus(sub_dataset_name=sub_dataset_name)
+
+                    queries_dict = {
+                        split: self.data_loader.load_queries(sub_dataset_name=sub_dataset_name)
+                    }
+
+                    all_queries = {}
+                    for _, split_queries in queries_dict.items():
+                        all_queries.update(split_queries)
+
+                    all_no_reranker_search_results = retriever(
+                        corpus=corpus,
+                        queries=all_queries,
+                        corpus_embd_save_dir=None if corpus_embd_save_dir is None else os.path.join(corpus_embd_save_dir, sub_dataset_name),
+                        batch_size=retriever_batch_size,
+                        query_max_length=retriever_query_max_length,
+                        passage_max_length=retriever_passage_max_length,
+                        **kwargs,
+                    )
+
+                    split_queries = queries_dict[split]
+                    no_reranker_search_results_dict[f"{sub_dataset_name}-{split}"] = {
+                        qid: all_no_reranker_search_results[qid] for qid in split_queries
+                    }
+                    split_no_reranker_search_results_save_path = os.path.join(
+                        no_reranker_search_results_save_dir, f"{sub_dataset_name}-{split}.json"
+                    )
+
+                    self.save_search_results(
+                        model_name=str(retriever),
+                        reranker_name="NoReranker",
+                        search_results=no_reranker_search_results_dict[f"{sub_dataset_name}-{split}"],
+                        output_path=split_no_reranker_search_results_save_path,
+                        split=f"{sub_dataset_name}-{split}",
+                        dataset_dir=self.dataset_dir,
+                    )
 
         else:
             for split in splits:
+                if sub_dataset_names is None:
                     split_no_reranker_search_results_save_path = os.path.join(
                         no_reranker_search_results_save_dir, f"{split}.json"
                     )
@@ -107,47 +153,93 @@
                         split=split,
                     )
                     no_reranker_search_results_dict[split] = search_results
+                else:
+                    for sub_dataset_name in sub_dataset_names:
+                        split_no_reranker_search_results_save_path = os.path.join(
+                            no_reranker_search_results_save_dir, f"{sub_dataset_name}-{split}.json"
+                        )
+                        data_info, search_results = self.load_search_results(split_no_reranker_search_results_save_path)
+
+                        self.check_data_info(
+                            data_info=data_info,
+                            model_name=str(retriever),
+                            reranker_name="NoReranker",
+                            split=f"{sub_dataset_name}-{split}",
+                        )
+                        no_reranker_search_results_dict[f"{sub_dataset_name}-{split}"] = search_results
+        retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir)
+        self.output_eval_results_to_json(retriever_eval_results, os.path.join(no_reranker_search_results_save_dir, 'eval.json'))
 
         # Reranking Stage
         if reranker is not None:
             reranker_search_results_save_dir = os.path.join(
-                search_results_save_dir, str(retriever), str(reranker)
+                search_results_save_dir, str(retriever), str(reranker), dataset_name
             )
             os.makedirs(reranker_search_results_save_dir, exist_ok=True)
 
-            corpus = self.data_loader.load_corpus()
+            if sub_dataset_names is None:
+                corpus = self.data_loader.load_corpus(sub_dataset_name=sub_dataset_names)
 
-            queries_dict = {
-                split: self.data_loader.load_queries(split=split)
-                for split in splits
-            }
+                queries_dict = {
+                    split: self.data_loader.load_queries(sub_dataset_name=sub_dataset_names)
+                }
 
             for split in splits:
+                if sub_dataset_names is None:
                     rerank_search_results_save_path = os.path.join(
                         reranker_search_results_save_dir, f"{split}.json"
                     )
 
                     if os.path.exists(rerank_search_results_save_path) and not self.overwrite:
-                        continue
+                        pass
+                    else:
+                        rerank_search_results = reranker(
+                            corpus=corpus,
+                            queries=queries_dict[split],
+                            search_results=no_reranker_search_results_dict[split],
+                            batch_size=reranker_batch_size,
+                            max_length=reranker_max_length,
+                            **kwargs,
+                        )
 
-                rerank_search_results = reranker(
-                    corpus=corpus,
-                    queries=queries_dict[split],
-                    search_results=no_reranker_search_results_dict[split],
-                    batch_size=reranker_batch_size,
-                    max_length=reranker_max_length,
-                    **kwargs,
-                )
+                        self.save_search_results(
+                            model_name=str(retriever),
+                            reranker_name=str(reranker),
+                            search_results=rerank_search_results,
+                            output_path=rerank_search_results_save_path,
+                            split=split,
+                            dataset_dir=self.dataset_dir,
+                        )
+                else:
+                    for sub_dataset_name in sub_dataset_names:
+                        corpus = self.data_loader.load_corpus(sub_dataset_name=sub_dataset_name)
+
+                        queries_dict = {
+                            split: self.data_loader.load_queries(sub_dataset_name=sub_dataset_name)
+                        }
+
+                        rerank_search_results_save_path = os.path.join(
+                            reranker_search_results_save_dir, f"{sub_dataset_name}-{split}.json"
+                        )
+
+                        if os.path.exists(rerank_search_results_save_path) and not self.overwrite:
+                            continue
+
+                        rerank_search_results = reranker(
+                            corpus=corpus,
+                            queries=queries_dict[split],
+                            search_results=no_reranker_search_results_dict[f"{sub_dataset_name}-{split}"],
+                            batch_size=reranker_batch_size,
+                            max_length=reranker_max_length,
+                            **kwargs,
+                        )
+
+                        self.save_search_results(
+                            model_name=str(retriever),
+                            reranker_name=str(reranker),
+                            search_results=rerank_search_results,
+                            output_path=rerank_search_results_save_path,
+                            split=f"{sub_dataset_name}-{split}",
+                            dataset_dir=self.dataset_dir,
+                        )
 
-                self.save_search_results(
-                    model_name=str(retriever),
-                    reranker_name=str(reranker),
-                    search_results=rerank_search_results,
-                    output_path=rerank_search_results_save_path,
-                    split=split,
-                    dataset_dir=self.dataset_dir,
-                )
+            reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir)
+            self.output_eval_results_to_json(reranker_eval_results, os.path.join(reranker_search_results_save_dir, 'eval.json'))
@@ -0,0 +1,44 @@
from transformers import HfArgumentParser

from FlagEmbedding import FlagAutoModel, FlagAutoReranker
from FlagEmbedding.abc.evaluation import AbsModelArgs, AbsEmbedder, AbsReranker, AbsEvaluator


from utils.arguments import MSMARCOEvalArgs
from utils.data_loader import MSMARCODataLoader


def get_models(model_args: AbsModelArgs):
    retriever = FlagAutoModel.from_finetuned(
        model_name_or_path=model_args.embedder_name_or_path,
        normalize_embeddings=model_args.normalize_embeddings,
        use_fp16=model_args.use_fp16,
        query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
        query_instruction_format=model_args.query_instruction_format_for_retrieval,
        devices=model_args.devices,
        examples_for_task=model_args.examples_for_task,
        examples_instruction_format=model_args.examples_instruction_format,
        trust_remote_code=model_args.trust_remote_code,
        cache_dir=model_args.cache_dir
    )
    reranker = None
    if model_args.reranker_name_or_path is not None:
        reranker = FlagAutoReranker.from_finetuned(
            model_name_or_path=model_args.reranker_name_or_path,
            peft_path=model_args.reranker_peft_path,
            use_fp16=model_args.use_fp16,
            use_bf16=model_args.use_bf16,
            query_instruction_for_rerank=model_args.query_instruction_for_rerank,
            query_instruction_format=model_args.query_instruction_format_for_rerank,
            passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
            passage_instruction_format=model_args.passage_instruction_format_for_rerank,
            cache_dir=model_args.cache_dir,
            trust_remote_code=model_args.trust_remote_code,
            devices=model_args.devices,
            normalize=model_args.normalize,
            prompt=model_args.prompt,
            cutoff_layers=model_args.cutoff_layers,
            compress_layers=model_args.compress_layers,
            compress_ratio=model_args.compress_ratio,
        )
    return retriever, reranker
@@ -1,81 +1,329 @@
-from transformers import HfArgumentParser
+import multiprocessing
+import os
+import random
+import sys
 
-from FlagEmbedding import AutoFlagModel
+import pandas as pd
+import torch
+import torch.nn.functional as F
+import tqdm
+import json
+import numpy as np
+import argparse
 
-from utils.arguments import ModelArgs
-from utils.models import SentenceTransformerEncoder, SentenceTransformerReranker
-from utils.searcher import EmbeddingModelRetriever, CrossEncoderReranker
+import mteb
 
+from peft import PeftModel
+from transformers import AutoModel, AutoTokenizer
+from transformers.modeling_outputs import BaseModelOutput
+from typing import List
+from mteb import MTEB
 
+from utils import logger, pool, move_to_cuda, get_detailed_instruct, get_task_def_by_task_name_and_type, create_batch_dict, tasks_desc, create_batch_query_dict
+parser = argparse.ArgumentParser(description='evaluation for MTEB benchmark except its Retrieval category')
+parser.add_argument('--task-types', nargs='+', default=[], help='task types to evaluate')
+parser.add_argument('--output-dir', default='',
+                    type=str, metavar='N', help='output directory')
+parser.add_argument('--model-name-or-path', default='tmp-outputs/',
+                    type=str, metavar='N', help='which model to use')
+parser.add_argument('--peft-name-or-path', default=None, type=str)
+parser.add_argument('--tokenizer-path', default=None)
+parser.add_argument('--embedding-path', default=None)
+parser.add_argument('--special-token', default=False, type=bool, help='whether to use special token')
+parser.add_argument('--zero-shot', default=False, type=bool, help='whether to use zero shot icl')
+parser.add_argument('--device', default=0, type=int)
+parser.add_argument('--all-num', default=8, type=int)
+parser.add_argument('--batch_size', default=32, type=int)
+parser.add_argument('--examples-dir', default='/share/chaofan/code/embedder/evaluate_for_icl/examples', type=str)
+parser.add_argument('--eight-special-token', default=False, type=bool)
+parser.add_argument('--passage-prompt', default=False, type=bool)
+
+args = parser.parse_args()
+base_name: str = args.model_name_or_path.split('/')[-1]
+if args.eight_special_token is True:
+    args.pool_type = 'last_eight'
+else:
+    args.pool_type = 'last'
+
+logger.info('Args: {}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4)))
+assert args.pool_type in ['cls', 'avg', 'last', 'weightedavg', 'last_eight'], 'pool_type should be cls / avg / last'
+os.makedirs(args.output_dir, exist_ok=True)
+ALL_NUM = args.all_num
+
+print(args.special_token)
+class DenseEncoder(torch.nn.Module):
+    def __init__(self, device, **kwargs):
+        super().__init__()
+        self.encoder = AutoModel.from_pretrained(args.model_name_or_path, use_cache=False)
+        if args.peft_name_or_path is not None:
+            peft_name_or_path = args.peft_name_or_path.split(',')
+            if args.embedding_path is not None:
+                self.encoder.set_input_embeddings(torch.load(args.embedding_path))
+            elif args.special_token and os.path.exists(os.path.join(peft_name_or_path[-1], 'embedding', 'emb.pth')):
+                self.encoder.set_input_embeddings(torch.load(os.path.join(peft_name_or_path[-1], 'embedding', 'emb.pth')))
+            for peft_path in peft_name_or_path:
+                self.encoder = PeftModel.from_pretrained(self.encoder, peft_path)
+                self.encoder = self.encoder.merge_and_unload()
+
+        if args.tokenizer_path is not None:
+            self.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
+        self.l2_normalize = True
+        self.prompt = None
+        self.prefix = ''
+        self.suffix = ''
+        self.gpu_count = torch.cuda.device_count()
+
+        self.encoder.half()
+        self.encoder.eval()
+        # self.encoder.cuda()
+        self.device = device
+        self.encoder = self.encoder.to(device)
+
+        self.eight_special_token = args.eight_special_token
+        if args.eight_special_token:
+            self.special_tokens = [self.tokenizer.eos_token_id] * 7
+        else:
+            self.special_tokens = []
+
+        self.passage_prompt = args.passage_prompt
+
+        self.batch_size = args.batch_size
+
+        # if self.gpu_count > 1:
+        #     self.encoder = torch.nn.DataParallel(self.encoder)
+    @torch.no_grad()
+    def encode(self, sentences, **kwargs) -> np.ndarray:
+        """ Returns a list of embeddings for the given sentences.
+        Args:
+            sentences (`List[str]`): List of sentences to encode
+            batch_size (`int`): Batch size for the encoding
+
+        Returns:
+            `List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
+        """
+
+        input_texts: List[str] = [self.prompt + s for s in sentences]
+
+        encoded_embeds = []
+        batch_size = self.batch_size * self.gpu_count
+        for start_idx in tqdm.tqdm(range(0, len(input_texts), batch_size), desc='encoding', mininterval=10):
+            batch_input_texts: List[str] = input_texts[start_idx: start_idx + batch_size]
+
+            batch_dict = create_batch_query_dict(self.tokenizer, self.prefix, self.suffix, batch_input_texts, special_tokens=self.special_tokens)
+            # if self.device == 0:
+            #     print(self.tokenizer.decode(batch_dict['input_ids'][0]))
+            batch_dict = batch_dict.to(self.device)
+            # batch_dict = move_to_cuda(batch_dict)
+
+            with torch.cuda.amp.autocast():
+                outputs: BaseModelOutput = self.encoder(**batch_dict)
+                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
+                if self.l2_normalize:
+                    embeds = F.normalize(embeds, p=2, dim=-1)
+                encoded_embeds.append(embeds.cpu().numpy())
+
+        return np.concatenate(encoded_embeds, axis=0)
+    @torch.no_grad()
+    def encode_queries(self, sentences, **kwargs) -> np.ndarray:
+        """ Returns a list of embeddings for the given sentences.
+        Args:
+            sentences (`List[str]`): List of sentences to encode
+            batch_size (`int`): Batch size for the encoding
+
+        Returns:
+            `List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
+        """
+
+        input_texts: List[str] = [self.prompt + s for s in sentences]
+
+        encoded_embeds = []
+        batch_size = self.batch_size * self.gpu_count
+        for start_idx in tqdm.tqdm(range(0, len(input_texts), batch_size), desc='encoding', mininterval=10):
+            batch_input_texts: List[str] = input_texts[start_idx: start_idx + batch_size]
+
+            batch_dict = create_batch_query_dict(self.tokenizer, self.prefix, self.suffix, batch_input_texts, special_tokens=self.special_tokens)
+            batch_dict = move_to_cuda(batch_dict)
+
+            with torch.cuda.amp.autocast():
+                outputs: BaseModelOutput = self.encoder(**batch_dict)
+                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
+                if self.l2_normalize:
+                    embeds = F.normalize(embeds, p=2, dim=-1)
+                encoded_embeds.append(embeds.cpu().numpy())
+
+        return np.concatenate(encoded_embeds, axis=0)
+    @torch.no_grad()
+    def encode_corpus(self, sentences, **kwargs) -> np.ndarray:
+        """ Returns a list of embeddings for the given sentences.
+        Args:
+            sentences (`List[str]`): List of sentences to encode
+            batch_size (`int`): Batch size for the encoding
+
+        Returns:
+            `List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
+        """
+
+        if isinstance(sentences[0], str):
+            input_texts: List[str] = [s for s in sentences]
+        else:
+            input_texts: List[str] = [(s['text'] + ' ' + s['title']).strip() for s in sentences]
+
+        encoded_embeds = []
+        batch_size = self.batch_size * self.gpu_count
+        for start_idx in tqdm.tqdm(range(0, len(input_texts), batch_size), desc='encoding', mininterval=10):
+            batch_input_texts: List[str] = input_texts[start_idx: start_idx + batch_size]
+
+            batch_dict = create_batch_dict(self.tokenizer, batch_input_texts, special_tokens=self.special_tokens, passage_prompt=self.passage_prompt)
+            batch_dict = move_to_cuda(batch_dict)
+
+            with torch.cuda.amp.autocast():
+                outputs: BaseModelOutput = self.encoder(**batch_dict)
+                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
+                if self.l2_normalize:
+                    embeds = F.normalize(embeds, p=2, dim=-1)
+                encoded_embeds.append(embeds.cpu().numpy())
+
+        return np.concatenate(encoded_embeds, axis=0)
+    def set_prompt(self, prompt: str):
+        self.prompt = prompt
+
+    def set_prefix(self, prefix: str):
+        self.prefix = prefix
+
+    def set_suffix(self, suffix: str):
+        self.suffix = suffix
+
-def get_models(model_args: ModelArgs):
-    embedding_model = AutoFlagModel.from_finetuned(
-        model_name_or_path,
-        normalize_embeddings,
-        use_fp16,
-        query_instruction_for_retrieval,
-        query_instruction_format,
-        devices,
-        examples_for_task,
-        examples_instruction_format,
-        trust_remote_code,
-        cache_dir
-    )
-    cross_encoder = None
-    if model_args.reranker is not None:
-        cross_encoder = SentenceTransformerReranker(
-            model_name_or_path,
-            peft_path,
-            use_fp16,
-            use_bf16,
-            query_instruction_for_rerank,
-            query_instruction_format,
-            passage_instruction_for_rerank,
-            passage_instruction_format,
-            cache_dir,
-            trust_remote_code,
-            devices
-        )
-    return embedding_model, cross_encoder
+def main(device, all_pairs):
+    torch.cuda.set_device(device)
+
+    model = DenseEncoder(device)
+
-def main():
-    parser = HfArgumentParser([ModelArgs, EvalArgs])
-    model_args, eval_args = parser.parse_args_into_dataclasses()
-    model_args: ModelArgs
-    eval_args: EvalArgs
+    os.environ['CUDA_VISIBLE_DEVICES'] = f'{device}'
 
-    embedding_model, cross_encoder = get_models(model_args)
-
-    evaluation = AIRBench(
-        benchmark_version=eval_args.benchmark_version,
-        task_types=eval_args.task_types,
-        domains=eval_args.domains,
-        languages=eval_args.languages,
-        splits=eval_args.splits,
-        cache_dir=eval_args.cache_dir,
-    )
-
-    retriever = EmbeddingModelRetriever(
-        embedding_model,
-        search_top_k=eval_args.search_top_k,
-        corpus_chunk_size=model_args.corpus_chunk_size,
-    )
-
-    if cross_encoder is not None:
-        reranker = CrossEncoderReranker(
-            cross_encoder,
-            rerank_top_k=eval_args.rerank_top_k,
-        )
+    ArxivClusteringP2P_FLAG = False
+    tmp = None
+    for i, p in enumerate(all_pairs):
+        if p[1] == 'ArxivClusteringP2P':
+            ArxivClusteringP2P_FLAG = True
+            tmp = p
+            break
+    if ArxivClusteringP2P_FLAG:
+        all_pairs.remove(tmp)
+
+    os.environ['CUDA_VISIBLE_DEVICES'] = f'{device}'
+    if ArxivClusteringP2P_FLAG is False:
+        length = len(all_pairs)
+        start = device * length // ALL_NUM
+        if device == ALL_NUM - 1:
+            end = length
+        else:
+            end = (device + 1) * length // ALL_NUM
+        all_pairs = all_pairs[start: end]
-    else:
-        reranker = None
-
-    evaluation.run(
-        retriever,
-        reranker=reranker,
-        output_dir=eval_args.output_dir,
-        overwrite=eval_args.overwrite,
-    )
+    else:
+        if device == ALL_NUM - 1:
+            all_pairs = [tmp]
+        else:
+            all_num = ALL_NUM - 1
+            length = len(all_pairs)
+            start = device * length // ALL_NUM
+            if device == all_num - 1:
+                end = length
+            else:
+                end = (device + 1) * length // all_num
+            all_pairs = all_pairs[start: end]
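A toy demonstration (standalone, not part of the commit) of the contiguous work split used above: task pairs are sharded across ALL_NUM workers by integer division, with the last worker absorbing the remainder:

ALL_NUM = 3
all_pairs = list(range(10))
for device in range(ALL_NUM):
    length = len(all_pairs)
    start = device * length // ALL_NUM
    end = length if device == ALL_NUM - 1 else (device + 1) * length // ALL_NUM
    print(device, all_pairs[start:end])
# 0 [0, 1, 2]
# 1 [3, 4, 5]
# 2 [6, 7, 8, 9]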
+    for (task_type, task_name) in all_pairs:
+        task_def: str = get_task_def_by_task_name_and_type(task_name=task_name, task_type=task_type)
+        prompt: str = get_detailed_instruct(task_def, args.special_token)
+        model.set_prompt(prompt=prompt)
+
+        eg_file_path = f'{args.examples_dir}/{task_name}.csv'
+        eg_pairs = []
+        if args.zero_shot:
+            eg_pairs = []
+        else:
+            df = pd.read_csv(eg_file_path)
+            for i in range(len(df)):
+                eg_pairs.append((get_detailed_instruct(
+                    task_def,
+                    args.special_token
+                ) + df[df.keys()[0]][i], df[df.keys()[1]][i]))
+        if args.special_token:
+            if len(eg_pairs) > 0:
+                prefix = '\n\n'.join(['\n<response>'.join(eg_pairs[idx]) for idx in range(len(eg_pairs))]) + '\n\n'
+            else:
+                prefix = ''
+            suffix = '\n<response>'
+        else:
+            if len(eg_pairs) > 0:
+                prefix = '\n\n'.join(['\nResponse: '.join(eg_pairs[idx]) for idx in range(len(eg_pairs))]) + '\n\n'
+            else:
+                prefix = ''
+            suffix = '\nResponse:'
+        model.set_prefix(prefix)
+        model.set_suffix(suffix)
+
+        logger.info('Set prompt: {}'.format(prompt))
+
+        # disable l2 normalize for classification tasks, as it achieves slightly better results
+        if task_type == 'Classification':
+            logger.info('Set l2_normalize to False for classification task')
+            model.l2_normalize = False
+        else:
+            model.l2_normalize = True
+        logger.info('Set l2_normalize to {}'.format(model.l2_normalize))
+
+        sub_eval = MTEB(tasks=[task_name], task_langs=['en'])
+
+        logger.info('Running evaluation for task: {}, type: {}'.format(task_name, task_type))
+
+        # eval_splits = ["test"] if "test" in task_cls.description["eval_splits"] else task_cls.description["eval_splits"]
+
+        result_flag = False
+        model.batch_size = args.batch_size
+        while result_flag is False:
+            try:
+                sub_eval.run(
+                    model,
+                    output_folder=args.output_dir
+                )
+                result_flag = True
+            except Exception as e:
+                model.batch_size -= 4
+                print(e)
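The retry loop above is a simple out-of-memory backoff pattern (shrink the batch and rerun until it fits); a generic standalone version, with illustrative names of our own, looks like this:

def run_with_backoff(fn, batch_size, step=4, min_batch_size=1):
    # fn is any callable that takes a batch size and may raise, e.g. on CUDA OOM.
    # Note the original code catches bare Exception; catching RuntimeError is
    # the narrower choice for CUDA OOM errors.
    while batch_size >= min_batch_size:
        try:
            return fn(batch_size)
        except RuntimeError as e:
            print(e)
            batch_size -= step
    raise RuntimeError("batch size exhausted")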
-
-if __name__ == "__main__":
-    main()
+
+if __name__ == '__main__':
+    processes = []
+    multiprocessing.set_start_method('spawn')
+
+    random.seed(30)
+    args.task_types = [t for t in args.task_types if t.strip()]
+    all_pairs = []
+    for task_type in args.task_types:
+        if task_type in tasks_desc.keys():
+            for task_name in tasks_desc[task_type]:
+                all_pairs.append((task_type, task_name))
+    for task_type in tasks_desc.keys():
+        for v in tasks_desc[task_type]:
+            if v in args.task_types:
+                all_pairs.append((task_type, v))
+    all_pairs = list(set(all_pairs))
+    random.shuffle(all_pairs)
+
+    for i in range(ALL_NUM):
+        # i = 7
+        process = multiprocessing.Process(target=main, args=(i, all_pairs,))
+        processes.append(process)
+        process.start()
+
+    for process in processes:
+        process.join()
365 FlagEmbedding/evaluation/mteb/utils.py Normal file
@@ -0,0 +1,365 @@
import sys

import torch
import logging

from torch import Tensor
from transformers import PreTrainedTokenizerFast, BatchEncoding
from typing import Mapping, Dict, List


def _setup_logger():
    log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_format)
    logger.handlers = [console_handler]

    return logger


logger = _setup_logger()

def move_to_cuda(sample):
    if len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda(non_blocking=True)
        elif isinstance(maybe_tensor, dict):
            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
        elif isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        elif isinstance(maybe_tensor, tuple):
            return tuple([_move_to_cuda(x) for x in maybe_tensor])
        elif isinstance(maybe_tensor, Mapping):
            return type(maybe_tensor)({k: _move_to_cuda(v) for k, v in maybe_tensor.items()})
        else:
            return maybe_tensor

    return _move_to_cuda(sample)
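Example usage of move_to_cuda on a nested structure (assuming the function above is in scope; guarded with is_available() so the sketch stays runnable on CPU-only machines):

import torch

if torch.cuda.is_available():
    batch = {"input_ids": torch.zeros(2, 4, dtype=torch.long),
             "masks": [torch.ones(2, 4), torch.ones(2, 4)]}
    batch = move_to_cuda(batch)  # recurses through dicts, lists, and tuples
    print(batch["input_ids"].device)  # cuda:0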
def pool(last_hidden_states: Tensor,
         attention_mask: Tensor,
         pool_type: str) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)

    if pool_type == "avg":
        emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    elif pool_type == "weightedavg":  # position-weighted mean pooling from SGPT (https://arxiv.org/abs/2202.08904)
        attention_mask *= attention_mask.cumsum(dim=1)  # [0,1,1,1,0,0] -> [0,1,2,3,0,0]
        s = torch.sum(last_hidden * attention_mask.unsqueeze(-1).float(), dim=1)
        d = attention_mask.sum(dim=1, keepdim=True).float()
        emb = s / d
    elif pool_type == "cls":
        emb = last_hidden[:, 0]
    elif pool_type == "last":
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            emb = last_hidden[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden.shape[0]
            emb = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths]
    elif pool_type == "last_eight":
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            emb = torch.mean(last_hidden[:, -8:, :], dim=-2)
        else:
            sys.exit()
    else:
        raise ValueError(f"pool_type {pool_type} not supported")

    return emb
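A quick shape check for the pooling modes above (assuming pool from this file is in scope; random inputs, batch of 2, sequence length 5, hidden size 8):

import torch

last_hidden = torch.randn(2, 5, 8)
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])  # right-padded

for mode in ["avg", "cls", "last"]:
    emb = pool(last_hidden, attention_mask, mode)
    print(mode, emb.shape)  # torch.Size([2, 8]) for every mode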
def create_batch_dict(tokenizer: PreTrainedTokenizerFast, input_texts: List[str], max_length: int = 512, special_tokens: list = [], passage_prompt: bool = False) -> BatchEncoding:
    passage_suffix = ''
    passage_suffix_len = 0
    if passage_prompt:
        passage_suffix = '\nSummarize the above passage: '
        passage_suffix_len = len(tokenizer(passage_suffix, add_special_tokens=False)['input_ids'])
    inputs = tokenizer(
        input_texts,
        max_length=max_length - len(tokenizer('<s>', add_special_tokens=False)['input_ids']) - len(
            tokenizer('</s>', add_special_tokens=False)['input_ids']) - len(special_tokens) - passage_suffix_len,
        return_token_type_ids=False,
        truncation=True,
        return_tensors=None,
        add_special_tokens=False
    )
    input_texts = tokenizer.batch_decode(inputs['input_ids'])
    if len(special_tokens) > 0:
        for i in range(len(input_texts)):
            input_texts[i] = input_texts[i] + passage_suffix + tokenizer.eos_token * 7
    else:
        for i in range(len(input_texts)):
            input_texts[i] = input_texts[i] + passage_suffix

    return tokenizer(
        input_texts,
        max_length=max_length,
        padding=True,
        pad_to_multiple_of=8,
        return_token_type_ids=False,
        truncation=True,
        return_tensors='pt'
    )
def create_batch_query_dict(tokenizer: PreTrainedTokenizerFast, prefix: str, suffix: str, input_texts: List[str], max_length: int = 512, special_tokens: list = []) -> BatchEncoding:
    inputs = tokenizer(
        input_texts,
        max_length=max_length - len(tokenizer('<s>', add_special_tokens=False)['input_ids']) - len(tokenizer('\n<response></s>', add_special_tokens=False)['input_ids']) - len(special_tokens),
        return_token_type_ids=False,
        truncation=True,
        return_tensors=None,
        add_special_tokens=False
    )
    prefix_ids = tokenizer(prefix, add_special_tokens=False)['input_ids']
    suffix_ids = tokenizer(suffix, add_special_tokens=False)['input_ids']
    new_max_length = (len(prefix_ids) + len(suffix_ids) + max_length) // 8 * 8 + 8

    input_texts = tokenizer.batch_decode(inputs['input_ids'])
    for i in range(len(input_texts)):
        if len(special_tokens) > 0:
            input_texts[i] = prefix + input_texts[i] + suffix + tokenizer.eos_token * 7
        else:
            input_texts[i] = prefix + input_texts[i] + suffix

    return tokenizer(
        input_texts,
        max_length=new_max_length,
        padding=True,
        pad_to_multiple_of=8,
        return_token_type_ids=False,
        truncation=True,
        return_tensors='pt'
    )
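The padded length computed above rounds up past the next multiple of 8 (friendly to fp16 tensor cores); a standalone check of that arithmetic (helper name is ours):

def rounded_length(prefix_len, suffix_len, max_length):
    return (prefix_len + suffix_len + max_length) // 8 * 8 + 8

assert rounded_length(10, 3, 512) == 528  # 525 -> floor to 520, then +8
assert rounded_length(0, 0, 512) == 520   # already a multiple of 8: still bumped by 8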
def get_task_def_by_task_name_and_type(task_name: str, task_type: str) -> str:
    if task_type in ['STS']:
        return "Retrieve semantically similar text."

    if task_type in ['Summarization']:
        return "Given a news summary, retrieve other semantically similar summaries."

    if task_type in ['BitextMining']:
        return "Retrieve parallel sentences."

    if task_type in ['Classification']:
        task_name_to_instruct: Dict[str, str] = {
            'AmazonCounterfactualClassification': 'Classify a given Amazon customer review text as either counterfactual or not-counterfactual.',
            'AmazonPolarityClassification': 'Classify Amazon reviews into positive or negative sentiment.',
            'AmazonReviewsClassification': 'Classify the given Amazon review into its appropriate rating category.',
            'Banking77Classification': 'Given a online banking query, find the corresponding intents.',
            'EmotionClassification': 'Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise.',
            'ImdbClassification': 'Classify the sentiment expressed in the given movie review text from the IMDB dataset.',
            'MassiveIntentClassification': 'Given a user utterance as query, find the user intents.',
            'MassiveScenarioClassification': 'Given a user utterance as query, find the user scenarios.',
            'MTOPDomainClassification': 'Classify the intent domain of the given utterance in task-oriented conversation.',
            'MTOPIntentClassification': 'Classify the intent of the given utterance in task-oriented conversation.',
            'ToxicConversationsClassification': 'Classify the given comments as either toxic or not toxic.',
            'TweetSentimentExtractionClassification': 'Classify the sentiment of a given tweet as either positive, negative, or neutral.',
            # C-MTEB eval instructions
            'TNews': 'Classify the fine-grained category of the given news title.',
            'IFlyTek': 'Given an App description text, find the appropriate fine-grained category.',
            'MultilingualSentiment': 'Classify sentiment of the customer review into positive, neutral, or negative.',
            'JDReview': 'Classify the customer review for iPhone on e-commerce platform into positive or negative.',
            'OnlineShopping': 'Classify the customer review for online shopping into positive or negative.',
            'Waimai': 'Classify the customer review from a food takeaway platform into positive or negative.',
        }
        return task_name_to_instruct[task_name]

    if task_type in ['Clustering']:
        task_name_to_instruct: Dict[str, str] = {
            'ArxivClusteringP2P': 'Identify the main and secondary category of Arxiv papers based on the titles and abstracts.',
            'ArxivClusteringS2S': 'Identify the main and secondary category of Arxiv papers based on the titles.',
            'BiorxivClusteringP2P': 'Identify the main category of Biorxiv papers based on the titles and abstracts.',
            'BiorxivClusteringS2S': 'Identify the main category of Biorxiv papers based on the titles.',
            'MedrxivClusteringP2P': 'Identify the main category of Medrxiv papers based on the titles and abstracts.',
            'MedrxivClusteringS2S': 'Identify the main category of Medrxiv papers based on the titles.',
            'RedditClustering': 'Identify the topic or theme of Reddit posts based on the titles.',
            'RedditClusteringP2P': 'Identify the topic or theme of Reddit posts based on the titles and posts.',
            'StackExchangeClustering': 'Identify the topic or theme of StackExchange posts based on the titles.',
            'StackExchangeClusteringP2P': 'Identify the topic or theme of StackExchange posts based on the given paragraphs.',
            'TwentyNewsgroupsClustering': 'Identify the topic or theme of the given news articles.',
            # C-MTEB eval instructions
            'CLSClusteringS2S': 'Identify the main category of scholar papers based on the titles.',
            'CLSClusteringP2P': 'Identify the main category of scholar papers based on the titles and abstracts.',
            'ThuNewsClusteringS2S': 'Identify the topic or theme of the given news articles based on the titles.',
            'ThuNewsClusteringP2P': 'Identify the topic or theme of the given news articles based on the titles and contents.',
        }
        return task_name_to_instruct[task_name]

    if task_type in ['Reranking', 'PairClassification']:
        task_name_to_instruct: Dict[str, str] = {
            'AskUbuntuDupQuestions': 'Retrieve duplicate questions from AskUbuntu forum.',
            'MindSmallReranking': 'Retrieve relevant news articles based on user browsing history.',
            'SciDocsRR': 'Given a title of a scientific paper, retrieve the titles of other relevant papers.',
            'StackOverflowDupQuestions': 'Retrieve duplicate questions from StackOverflow forum.',
            'SprintDuplicateQuestions': 'Retrieve duplicate questions from Sprint forum.',
            'TwitterSemEval2015': 'Retrieve tweets that are semantically similar to the given tweet.',
            'TwitterURLCorpus': 'Retrieve tweets that are semantically similar to the given tweet.',
            # C-MTEB eval instructions
            'T2Reranking': 'Given a Chinese search query, retrieve web passages that answer the question.',
            'MMarcoReranking': 'Given a Chinese search query, retrieve web passages that answer the question.',
            'CMedQAv1': 'Given a Chinese community medical question, retrieve replies that best answer the question.',
            'CMedQAv2': 'Given a Chinese community medical question, retrieve replies that best answer the question.',
            'Ocnli': 'Retrieve semantically similar text.',
            'Cmnli': 'Retrieve semantically similar text.',
        }
        return task_name_to_instruct[task_name]

    if task_type in ['Retrieval']:
        if task_name.lower().startswith('cqadupstack'):
            return 'Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question.'

        task_name_to_instruct: Dict[str, str] = {
            'ArguAna': 'Given a claim, find documents that refute the claim.',
            'ClimateFEVER': 'Given a claim about climate change, retrieve documents that support or refute the claim.',
            'DBPedia': 'Given a query, retrieve relevant entity descriptions from DBPedia.',
            'FEVER': 'Given a claim, retrieve documents that support or refute the claim.',
            'FiQA2018': 'Given a financial question, retrieve user replies that best answer the question.',
            'HotpotQA': 'Given a multi-hop question, retrieve documents that can help answer the question.',
            'MSMARCO': 'Given a web search query, retrieve relevant passages that answer the query.',
            'NFCorpus': 'Given a question, retrieve relevant documents that best answer the question.',
            'NQ': 'Given a question, retrieve Wikipedia passages that answer the question.',
            'QuoraRetrieval': 'Given a question, retrieve questions that are semantically equivalent to the given question.',
            'SCIDOCS': 'Given a scientific paper title, retrieve paper abstracts that are cited by the given paper.',
            'SciFact': 'Given a scientific claim, retrieve documents that support or refute the claim.',
            'Touche2020': 'Given a question, retrieve detailed and persuasive arguments that answer the question.',
            'TRECCOVID': 'Given a query on COVID-19, retrieve documents that answer the query.',
            # C-MTEB eval instructions
            'T2Retrieval': 'Given a Chinese search query, retrieve web passages that answer the question.',
            'MMarcoRetrieval': 'Given a web search query, retrieve relevant passages that answer the query.',
            'DuRetrieval': 'Given a Chinese search query, retrieve web passages that answer the question.',
            'CovidRetrieval': 'Given a question on COVID-19, retrieve news articles that answer the question.',
            'CmedqaRetrieval': 'Given a Chinese community medical question, retrieve replies that best answer the question.',
            'EcomRetrieval': 'Given a user query from an e-commerce website, retrieve description sentences of relevant products.',
            'MedicalRetrieval': 'Given a medical question, retrieve user replies that best answer the question.',
            'VideoRetrieval': 'Given a video search query, retrieve the titles of relevant videos.',
        }

        # add lower case keys to match some beir names
        task_name_to_instruct.update({k.lower(): v for k, v in task_name_to_instruct.items()})
        # other cases where lower case match still doesn't work
        task_name_to_instruct['trec-covid'] = task_name_to_instruct['TRECCOVID']
        task_name_to_instruct['climate-fever'] = task_name_to_instruct['ClimateFEVER']
        task_name_to_instruct['dbpedia-entity'] = task_name_to_instruct['DBPedia']
        task_name_to_instruct['webis-touche2020'] = task_name_to_instruct['Touche2020']
        task_name_to_instruct['fiqa'] = task_name_to_instruct['FiQA2018']
        task_name_to_instruct['quora'] = task_name_to_instruct['QuoraRetrieval']

        # for miracl evaluation
        task_name_to_instruct['miracl'] = 'Given a question, retrieve Wikipedia passages that answer the question.'

        return task_name_to_instruct[task_name]

    raise ValueError(f"No instruction config for task {task_name} with type {task_type}")
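Example lookups against the instruction table above (assuming this module is importable as utils, as the evaluation script's imports suggest):

from utils import get_task_def_by_task_name_and_type

print(get_task_def_by_task_name_and_type('NFCorpus', 'Retrieval'))
# Given a question, retrieve relevant documents that best answer the question.
print(get_task_def_by_task_name_and_type('nfcorpus', 'Retrieval'))  # lower-case alias also resolves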
def get_detailed_instruct(task_description: str, special_token: bool) -> str:
    if not task_description:
        return ''

    if special_token:
        return f'<instruct>{task_description}\n<query>'
    else:
        return f'Instruct: {task_description}\nQuery: '
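The two prompt formats produced by get_detailed_instruct, for reference (assuming the function above is in scope):

task_def = "Given a web search query, retrieve relevant passages that answer the query."
print(get_detailed_instruct(task_def, special_token=False))
# Instruct: Given a web search query, retrieve relevant passages that answer the query.
# Query:
print(get_detailed_instruct(task_def, special_token=True))
# <instruct>Given a web search query, retrieve relevant passages that answer the query.
# <query>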
tasks_desc = {
    'Retrieval': [
        'ArguAna',
        'ClimateFEVER',
        'DBPedia',
        'FEVER',
        'FiQA2018',
        'HotpotQA',
        'MSMARCO',
        'NFCorpus',
        'NQ',
        'QuoraRetrieval',
        'SCIDOCS',
        'SciFact',
        'Touche2020',
        'TRECCOVID',
        'CQADupstackAndroidRetrieval',
        'CQADupstackEnglishRetrieval',
        'CQADupstackGamingRetrieval',
        'CQADupstackGisRetrieval',
        'CQADupstackMathematicaRetrieval',
        'CQADupstackPhysicsRetrieval',
        'CQADupstackProgrammersRetrieval',
        'CQADupstackStatsRetrieval',
        'CQADupstackTexRetrieval',
        'CQADupstackUnixRetrieval',
        'CQADupstackWebmastersRetrieval',
        'CQADupstackWordpressRetrieval'
    ],
    'Classification': [
        # 12
        'AmazonCounterfactualClassification',
        'AmazonPolarityClassification',
        'AmazonReviewsClassification',
        'Banking77Classification',
        'EmotionClassification',
        'ImdbClassification',
        'MassiveIntentClassification',
        'MassiveScenarioClassification',
        'MTOPDomainClassification',
        'MTOPIntentClassification',
        'ToxicConversationsClassification',
        'TweetSentimentExtractionClassification',
    ],
    'Clustering': [
        # 11
        'ArxivClusteringP2P',
        'ArxivClusteringS2S',
        'BiorxivClusteringP2P',
        'BiorxivClusteringS2S',
        'MedrxivClusteringP2P',
        'MedrxivClusteringS2S',
        'RedditClustering',
        'RedditClusteringP2P',
        'StackExchangeClustering',
        'StackExchangeClusteringP2P',
        'TwentyNewsgroupsClustering',
    ],
    'PairClassification': [
        # 3
        'SprintDuplicateQuestions',
        'TwitterSemEval2015',
        'TwitterURLCorpus',
    ],
    'Reranking': [
        # 4
        'AskUbuntuDupQuestions',
        'MindSmallReranking',
        'SciDocsRR',
        'StackOverflowDupQuestions',
    ],
    'STS': [
        # 10
        'BIOSSES',
        'SICK-R',
        'STS12',
        'STS13',
        'STS14',
        'STS15',
        'STS16',
        'STS17',
        'STS22',
        'STSBenchmark',
    ],
    'Summarization': [
        # 1
        'SummEval',
    ]
}
|
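
A minimal sketch (not part of the commit) of how the helpers above can be combined: precompute one query prefix per retrieval task. The `use_special_token` flag is hypothetical and depends on the embedder's prompt format:

# Sketch: build a task -> query-prefix map for every Retrieval task above.
use_special_token = False  # hypothetical flag; set True for <instruct>/<query> style models
retrieval_prompts = {
    name: get_detailed_instruct(
        get_task_def_by_task_name_and_type(name, 'Retrieval'),
        special_token=use_special_token,
    )
    for name in tasks_desc['Retrieval']
}
print(retrieval_prompts['SciFact'])
# 'Instruct: Given a scientific claim, retrieve documents that support or refute the claim.\nQuery: '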
216 FlagEmbedding/evaluation/mteb/utils/prompts.py Normal file
@ -0,0 +1,216 @@
import sys
from typing import Dict

import torch


def get_task_def_by_task_name_and_type(task_name: str, task_type: str) -> str:
    if task_type in ['STS']:
        return "Retrieve semantically similar text."

    if task_type in ['Summarization']:
        return "Given a news summary, retrieve other semantically similar summaries."

    if task_type in ['BitextMining']:
        return "Retrieve parallel sentences."

    if task_type in ['Classification']:
        task_name_to_instruct: Dict[str, str] = {
            'AmazonCounterfactualClassification': 'Classify a given Amazon customer review text as either counterfactual or not-counterfactual.',
            'AmazonPolarityClassification': 'Classify Amazon reviews into positive or negative sentiment.',
            'AmazonReviewsClassification': 'Classify the given Amazon review into its appropriate rating category.',
            'Banking77Classification': 'Given a online banking query, find the corresponding intents.',
            'EmotionClassification': 'Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise.',
            'ImdbClassification': 'Classify the sentiment expressed in the given movie review text from the IMDB dataset.',
            'MassiveIntentClassification': 'Given a user utterance as query, find the user intents.',
            'MassiveScenarioClassification': 'Given a user utterance as query, find the user scenarios.',
            'MTOPDomainClassification': 'Classify the intent domain of the given utterance in task-oriented conversation.',
            'MTOPIntentClassification': 'Classify the intent of the given utterance in task-oriented conversation.',
            'ToxicConversationsClassification': 'Classify the given comments as either toxic or not toxic.',
            'TweetSentimentExtractionClassification': 'Classify the sentiment of a given tweet as either positive, negative, or neutral.',
            # C-MTEB eval instructions
            'TNews': 'Classify the fine-grained category of the given news title.',
            'IFlyTek': 'Given an App description text, find the appropriate fine-grained category.',
            'MultilingualSentiment': 'Classify sentiment of the customer review into positive, neutral, or negative.',
            'JDReview': 'Classify the customer review for iPhone on e-commerce platform into positive or negative.',
            'OnlineShopping': 'Classify the customer review for online shopping into positive or negative.',
            'Waimai': 'Classify the customer review from a food takeaway platform into positive or negative.',
        }
        return task_name_to_instruct[task_name]

    if task_type in ['Clustering']:
        task_name_to_instruct: Dict[str, str] = {
            'ArxivClusteringP2P': 'Identify the main and secondary category of Arxiv papers based on the titles and abstracts.',
            'ArxivClusteringS2S': 'Identify the main and secondary category of Arxiv papers based on the titles.',
            'BiorxivClusteringP2P': 'Identify the main category of Biorxiv papers based on the titles and abstracts.',
            'BiorxivClusteringS2S': 'Identify the main category of Biorxiv papers based on the titles.',
            'MedrxivClusteringP2P': 'Identify the main category of Medrxiv papers based on the titles and abstracts.',
            'MedrxivClusteringS2S': 'Identify the main category of Medrxiv papers based on the titles.',
            'RedditClustering': 'Identify the topic or theme of Reddit posts based on the titles.',
            'RedditClusteringP2P': 'Identify the topic or theme of Reddit posts based on the titles and posts.',
            'StackExchangeClustering': 'Identify the topic or theme of StackExchange posts based on the titles.',
            'StackExchangeClusteringP2P': 'Identify the topic or theme of StackExchange posts based on the given paragraphs.',
            'TwentyNewsgroupsClustering': 'Identify the topic or theme of the given news articles.',
            # C-MTEB eval instructions
            'CLSClusteringS2S': 'Identify the main category of scholar papers based on the titles.',
            'CLSClusteringP2P': 'Identify the main category of scholar papers based on the titles and abstracts.',
            'ThuNewsClusteringS2S': 'Identify the topic or theme of the given news articles based on the titles.',
            'ThuNewsClusteringP2P': 'Identify the topic or theme of the given news articles based on the titles and contents.',
        }
        return task_name_to_instruct[task_name]

    if task_type in ['Reranking', 'PairClassification']:
        task_name_to_instruct: Dict[str, str] = {
            'AskUbuntuDupQuestions': 'Retrieve duplicate questions from AskUbuntu forum.',
            'MindSmallReranking': 'Retrieve relevant news articles based on user browsing history.',
            'SciDocsRR': 'Given a title of a scientific paper, retrieve the titles of other relevant papers.',
            'StackOverflowDupQuestions': 'Retrieve duplicate questions from StackOverflow forum.',
            'SprintDuplicateQuestions': 'Retrieve duplicate questions from Sprint forum.',
            'TwitterSemEval2015': 'Retrieve tweets that are semantically similar to the given tweet.',
            'TwitterURLCorpus': 'Retrieve tweets that are semantically similar to the given tweet.',
            # C-MTEB eval instructions
            'T2Reranking': 'Given a Chinese search query, retrieve web passages that answer the question.',
            'MMarcoReranking': 'Given a Chinese search query, retrieve web passages that answer the question.',
            'CMedQAv1': 'Given a Chinese community medical question, retrieve replies that best answer the question.',
            'CMedQAv2': 'Given a Chinese community medical question, retrieve replies that best answer the question.',
            'Ocnli': 'Retrieve semantically similar text.',
            'Cmnli': 'Retrieve semantically similar text.',
        }
        return task_name_to_instruct[task_name]

    if task_type in ['Retrieval']:
        if task_name.lower().startswith('cqadupstack'):
            return 'Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question.'

        task_name_to_instruct: Dict[str, str] = {
            'ArguAna': 'Given a claim, find documents that refute the claim.',
            'ClimateFEVER': 'Given a claim about climate change, retrieve documents that support or refute the claim.',
            'DBPedia': 'Given a query, retrieve relevant entity descriptions from DBPedia.',
            'FEVER': 'Given a claim, retrieve documents that support or refute the claim.',
            'FiQA2018': 'Given a financial question, retrieve user replies that best answer the question.',
            'HotpotQA': 'Given a multi-hop question, retrieve documents that can help answer the question.',
            'MSMARCO': 'Given a web search query, retrieve relevant passages that answer the query.',
            'NFCorpus': 'Given a question, retrieve relevant documents that best answer the question.',
            'NQ': 'Given a question, retrieve Wikipedia passages that answer the question.',
            'QuoraRetrieval': 'Given a question, retrieve questions that are semantically equivalent to the given question.',
            'SCIDOCS': 'Given a scientific paper title, retrieve paper abstracts that are cited by the given paper.',
            'SciFact': 'Given a scientific claim, retrieve documents that support or refute the claim.',
            'Touche2020': 'Given a question, retrieve detailed and persuasive arguments that answer the question.',
            'TRECCOVID': 'Given a query on COVID-19, retrieve documents that answer the query.',
            # C-MTEB eval instructions
            'T2Retrieval': 'Given a Chinese search query, retrieve web passages that answer the question.',
            'MMarcoRetrieval': 'Given a web search query, retrieve relevant passages that answer the query.',
            'DuRetrieval': 'Given a Chinese search query, retrieve web passages that answer the question.',
            'CovidRetrieval': 'Given a question on COVID-19, retrieve news articles that answer the question.',
            'CmedqaRetrieval': 'Given a Chinese community medical question, retrieve replies that best answer the question.',
            'EcomRetrieval': 'Given a user query from an e-commerce website, retrieve description sentences of relevant products.',
            'MedicalRetrieval': 'Given a medical question, retrieve user replies that best answer the question.',
            'VideoRetrieval': 'Given a video search query, retrieve the titles of relevant videos.',
        }

        # add lower case keys to match some beir names
        task_name_to_instruct.update({k.lower(): v for k, v in task_name_to_instruct.items()})
        # other cases where lower case match still doesn't work
        task_name_to_instruct['trec-covid'] = task_name_to_instruct['TRECCOVID']
        task_name_to_instruct['climate-fever'] = task_name_to_instruct['ClimateFEVER']
        task_name_to_instruct['dbpedia-entity'] = task_name_to_instruct['DBPedia']
        task_name_to_instruct['webis-touche2020'] = task_name_to_instruct['Touche2020']
        task_name_to_instruct['fiqa'] = task_name_to_instruct['FiQA2018']
        task_name_to_instruct['quora'] = task_name_to_instruct['QuoraRetrieval']

        # for miracl evaluation
        task_name_to_instruct['miracl'] = 'Given a question, retrieve Wikipedia passages that answer the question.'

        return task_name_to_instruct[task_name]

    raise ValueError(f"No instruction config for task {task_name} with type {task_type}")
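
Because of the lower-casing pass and the explicit aliases above, BEIR-style dataset names resolve to the same instruction as their MTEB counterparts. A quick check (illustration only, not part of the commit):

# Illustration: BEIR aliases resolve to the MTEB instruction text.
assert get_task_def_by_task_name_and_type('trec-covid', 'Retrieval') == \
    get_task_def_by_task_name_and_type('TRECCOVID', 'Retrieval')
print(get_task_def_by_task_name_and_type('miracl', 'Retrieval'))
# 'Given a question, retrieve Wikipedia passages that answer the question.'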


tasks_desc = {
    'Retrieval': [
        'ArguAna',
        'ClimateFEVER',
        'DBPedia',
        'FEVER',
        'FiQA2018',
        'HotpotQA',
        'MSMARCO',
        'NFCorpus',
        'NQ',
        'QuoraRetrieval',
        'SCIDOCS',
        'SciFact',
        'Touche2020',
        'TRECCOVID',
        'CQADupstackAndroidRetrieval',
        'CQADupstackEnglishRetrieval',
        'CQADupstackGamingRetrieval',
        'CQADupstackGisRetrieval',
        'CQADupstackMathematicaRetrieval',
        'CQADupstackPhysicsRetrieval',
        'CQADupstackProgrammersRetrieval',
        'CQADupstackStatsRetrieval',
        'CQADupstackTexRetrieval',
        'CQADupstackUnixRetrieval',
        'CQADupstackWebmastersRetrieval',
        'CQADupstackWordpressRetrieval'
    ],
    'Classification': [
        # 12
        'AmazonCounterfactualClassification',
        'AmazonPolarityClassification',
        'AmazonReviewsClassification',
        'Banking77Classification',
        'EmotionClassification',
        'ImdbClassification',
        'MassiveIntentClassification',
        'MassiveScenarioClassification',
        'MTOPDomainClassification',
        'MTOPIntentClassification',
        'ToxicConversationsClassification',
        'TweetSentimentExtractionClassification',
    ],
    'Clustering': [
        # 11
        'ArxivClusteringP2P',
        'ArxivClusteringS2S',
        'BiorxivClusteringP2P',
        'BiorxivClusteringS2S',
        'MedrxivClusteringP2P',
        'MedrxivClusteringS2S',
        'RedditClustering',
        'RedditClusteringP2P',
        'StackExchangeClustering',
        'StackExchangeClusteringP2P',
        'TwentyNewsgroupsClustering',
    ],
    'PairClassification': [
        # 3
        'SprintDuplicateQuestions',
        'TwitterSemEval2015',
        'TwitterURLCorpus',
    ],
    'Reranking': [
        # 4
        'AskUbuntuDupQuestions',
        'MindSmallReranking',
        'SciDocsRR',
        'StackOverflowDupQuestions',
    ],
    'STS': [
        # 10
        'BIOSSES',
        'SICK-R',
        'STS12',
        'STS13',
        'STS14',
        'STS15',
        'STS16',
        'STS17',
        'STS22',
        'STSBenchmark',
    ],
    'Summarization': [
        # 1
        'SummEval',
    ]
}
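
One caveat when driving evaluation from tasks_desc: a task that appears in the list but is missing from the instruction tables surfaces as a bare KeyError. A small defensive wrapper (a sketch, not part of the commit) makes that failure explicit:

# Sketch: wrap the lookup so a missing task name fails with a clear message.
def safe_task_instruction(task_name: str, task_type: str) -> str:
    try:
        return get_task_def_by_task_name_and_type(task_name, task_type)
    except KeyError as err:
        raise ValueError(
            f"No instruction configured for task {task_name!r} of type {task_type!r}"
        ) from err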