eval beir

cfli 2024-10-24 15:48:21 +08:00
parent 6d71e83cf2
commit 00a42ccd4f
90 changed files with 1272 additions and 160 deletions

View File

@@ -0,0 +1,96 @@
from transformers import HfArgumentParser
from FlagEmbedding import FlagAutoModel, FlagAutoReranker
from FlagEmbedding.abc.evaluation import AbsModelArgs, AbsEmbedder, AbsReranker
from utils.arguments import BEIREvalArgs
from utils.data_loader import BEIRDataLoader
from utils.evaluator import BEIREvaluator
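# Build the embedding retriever and, if a reranker checkpoint is given, the reranker.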
def get_models(model_args: AbsModelArgs):
retriever = FlagAutoModel.from_finetuned(
model_name_or_path=model_args.embedder_name_or_path,
normalize_embeddings=model_args.normalize_embeddings,
use_fp16=model_args.use_fp16,
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
query_instruction_format=model_args.query_instruction_format_for_retrieval,
devices=model_args.devices,
examples_for_task=model_args.examples_for_task,
examples_instruction_format=model_args.examples_instruction_format,
trust_remote_code=model_args.trust_remote_code,
cache_dir=model_args.cache_dir
)
reranker = None
if model_args.reranker_name_or_path is not None:
reranker = FlagAutoReranker.from_finetuned(
model_name_or_path=model_args.reranker_name_or_path,
peft_path=model_args.reranker_peft_path,
use_fp16=model_args.use_fp16,
use_bf16=model_args.use_bf16,
query_instruction_for_rerank=model_args.query_instruction_for_rerank,
query_instruction_format=model_args.query_instruction_format_for_rerank,
passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
passage_instruction_format=model_args.passage_instruction_format_for_rerank,
cache_dir=model_args.cache_dir,
trust_remote_code=model_args.trust_remote_code,
devices=model_args.devices,
normalize=model_args.normalize,
prompt=model_args.prompt,
cutoff_layers=model_args.cutoff_layers,
compress_layers=model_args.compress_layers,
compress_ratio=model_args.compress_ratio,
)
return retriever, reranker
def main():
parser = HfArgumentParser([AbsModelArgs, BEIREvalArgs])
model_args, eval_args = parser.parse_args_into_dataclasses()
model_args: AbsModelArgs
eval_args: BEIREvalArgs
retriever, reranker = get_models(model_args)
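# Wrap the raw models in the abstract searcher interfaces the evaluator expects, configuring top-k retrieval and reranking.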
retriever = AbsEmbedder(
retriever,
search_top_k=eval_args.search_top_k,
)
if reranker is not None:
reranker = AbsReranker(
reranker,
rerank_top_k=eval_args.rerank_top_k,
)
else:
reranker = None
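# Evaluate each requested BEIR dataset independently; results land under output_dir.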
for dataset_name in eval_args.dataset_names:
data_loader = BEIRDataLoader(
dataset_dir=eval_args.dataset_dir,
cache_dir=eval_args.cache_path,
dataset_name=dataset_name
)
evaluation = BEIREvaluator(
data_loader=data_loader,
overwrite=eval_args.overwrite,
)
evaluation(
splits=eval_args.splits.split(),
search_results_save_dir=eval_args.output_dir,
retriever=retriever,
reranker=reranker,
corpus_embd_save_dir=eval_args.corpus_embd_save_dir,
retriever_batch_size=model_args.retriever_batch_size,
reranker_batch_size=model_args.reranker_batch_size,
retriever_query_max_length=model_args.retriever_query_max_length,
retriever_passage_max_length=model_args.retriever_passage_max_length,
reranker_max_length=model_args.reranker_max_length,
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,12 @@
python __main__.py \
--dataset_dir /share/chaofan/code/FlagEmbedding_update/data/BEIR \
--embedder_name_or_path BAAI/bge-large-en-v1.5 \
--reranker_name_or_path BAAI/bge-reranker-large \
--query_instruction_for_retrieval "Represent this sentence for searching relevant passages: " \
--use_fp16 True \
--devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7 \
--cache_dir /share/shared_models \
--corpus_embd_save_dir /share/chaofan/code/FlagEmbedding_update/data/BEIR_passage_embds \
--reranker_max_length 512 \
--dataset_names nfcorpus fiqa cqadupstack

View File

@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import List
from FlagEmbedding.abc.evaluation.arguments import AbsEvalArgs

View File

@@ -15,7 +15,7 @@ from FlagEmbedding.abc.evaluation.data_loader import AbsDataLoader
logger = logging.getLogger(__name__)
class MSMARCODataLoader(AbsDataLoader):
class BEIRDataLoader(AbsDataLoader):
def __init__(
self,
dataset_dir: str, # the dataset dir to load from
@@ -32,16 +32,15 @@ class MSMARCODataLoader(AbsDataLoader):
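# Every BEIR dataset is evaluated on its test split except msmarco, which uses dev.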
self.split = 'test'
if dataset_name == 'msmarco': self.split = 'dev'
os.makedirs(self.dataset_dir, exist_ok=True)
if dataset_name != 'cqadupstack':
os.makedirs(self.dataset_dir, exist_ok=True)
qrels_path = os.path.join(self.dataset_dir, 'qrels-{split}.json'.format(split=self.split))
corpus_path = os.path.join(self.dataset_dir, 'corpus.json')
queries_path = os.path.join(self.dataset_dir, 'queries-{split}.json'.format(split=self.split))
queries, corpus, rels = {}, {}, {}
if not os.path.exists(corpus_path):
dataset = datasets.load_dataset(
'BeIR/{d}'.format(d=data_name),
'BeIR/{d}'.format(d=dataset_name),
'corpus',
trust_remote_code=True,
cache_dir=os.getenv('HF_HUB_CACHE', cache_dir),
@@ -55,7 +54,7 @@ class MSMARCODataLoader(AbsDataLoader):
if not os.path.exists(queries_path):
dataset = datasets.load_dataset(
'BeIR/{d}'.format(d=data_name),
'BeIR/{d}'.format(d=dataset_name),
'queries',
trust_remote_code=True,
cache_dir=os.getenv('HF_HUB_CACHE', cache_dir),
@@ -69,7 +68,7 @@ class MSMARCODataLoader(AbsDataLoader):
if not os.path.exists(qrels_path):
dataset = datasets.load_dataset(
'BeIR/{d}-qrels'.format(d=data_name),
'BeIR/{d}-qrels'.format(d=dataset_name),
split=self.split,
trust_remote_code=True,
cache_dir=os.getenv('HF_HUB_CACHE', cache_dir),
@@ -93,10 +92,8 @@ class MSMARCODataLoader(AbsDataLoader):
if not os.path.exists(corpus_path) or not os.path.exists(queries_path) or not os.path.exists(qrels_path):
full_path = os.path.join(data_path, sub_dataset_name)
corpus, queries, qrels = GenericDataLoader(data_folder=full_path).load(split="test")
for k in corpus_data.keys():
corpus[k] = {
(corpus[k]['title'] + ' ' + corpus[k]['text']).strip()
}
for k in corpus.keys():
corpus[k] = (corpus[k]['title'] + ' ' + corpus[k]['text']).strip()
with open(corpus_path, 'w') as f:
json.dump(corpus, f)
@@ -105,25 +102,40 @@ class MSMARCODataLoader(AbsDataLoader):
json.dump(queries, f)
with open(qrels_path, 'w') as f:
json.dump(rels, f)
json.dump(qrels, f)
def load_qrels(self):
if self.sub_dataset_name is None:
def load_qrels(self, sub_dataset_name: str = None, split: str = None):
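# A split string such as 'android-test' encodes the cqadupstack sub-dataset name before the dash.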
if sub_dataset_name is None and split is not None:
sub_dataset_name = split.split('-')
if len(sub_dataset_name) == 1 or len(sub_dataset_name[0].strip()) == 0:
sub_dataset_name = None
else:
sub_dataset_name = sub_dataset_name[0].strip()
if sub_dataset_name is None:
qrels_path = os.path.join(self.dataset_dir, 'qrels-{split}.json'.format(split=self.split))
rels = json.load(open(qrels_path))
return datasets.DatasetDict(rels)
else:
rels_list = []
for sub_dataset_name in self.sub_dataset_names:
qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
rels = json.load(open(qrels_path))
rels_list.append(datasets.DatasetDict(rels))
return rels_list
# rels_list = []
# for sub_dataset_name in self.sub_dataset_names:
# qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
# rels = json.load(open(qrels_path))
# rels_list.append(datasets.DatasetDict(rels))
# return rels_list
qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
rels = json.load(open(qrels_path))
return datasets.DatasetDict(rels)
def load_corpus(self):
if self.sub_dataset_name is None:
def load_corpus(self, sub_dataset_name: str = None, split: str = None):
if sub_dataset_name is None and split is not None:
sub_dataset_name = split.split('-')
if len(sub_dataset_name) == 1 or len(sub_dataset_name[0].strip()) == 0:
sub_dataset_name = None
else:
sub_dataset_name = sub_dataset_name[0].strip()
if sub_dataset_name is None:
corpus_path = os.path.join(self.dataset_dir, 'corpus.json')
corpus = json.load(open(corpus_path))
for k in corpus.keys():
@@ -132,15 +144,28 @@ class MSMARCODataLoader(AbsDataLoader):
}
return datasets.DatasetDict(corpus)
else:
corpus_list = []
for sub_dataset_name in self.sub_dataset_names:
corpus_path = os.path.join(self.dataset_dir, sub_dataset_name, 'corpus.json')
corpus = json.load(open(corpus_path))
corpus_list.append(datasets.DatasetDict(corpus))
return corpus_list
corpus_path = os.path.join(self.dataset_dir, sub_dataset_name, 'corpus.json')
corpus = json.load(open(corpus_path))
for k in corpus.keys():
corpus[k] = {
'text': corpus[k]
}
return datasets.DatasetDict(corpus)
# corpus_list = []
# for sub_dataset_name in self.sub_dataset_names:
# corpus_path = os.path.join(self.dataset_dir, sub_dataset_name, 'corpus.json')
# corpus = json.load(open(corpus_path))
# corpus_list.append(datasets.DatasetDict(corpus))
# return corpus_list
def load_queries(self, split='dev'):
if self.sub_dataset_name is None:
def load_queries(self, sub_dataset_name: str = None, split: str = None):
if sub_dataset_name is None and split is not None:
sub_dataset_name = split.split('-')
if len(sub_dataset_name) == 1 or len(sub_dataset_name[0].strip()) == 0:
sub_dataset_name = None
else:
sub_dataset_name = sub_dataset_name[0].strip()
if sub_dataset_name is None:
queries_path = os.path.join(self.dataset_dir, 'queries-{split}.json'.format(split=self.split))
qrels_path = os.path.join(self.dataset_dir, 'qrels-{split}.json'.format(split=self.split))
queries = json.load(open(queries_path))
@@ -151,15 +176,25 @@ class MSMARCODataLoader(AbsDataLoader):
new_queries[k] = queries[k]
return datasets.DatasetDict(new_queries)
else:
queries_list = []
for sub_dataset_name in self.sub_dataset_names:
queries_path = os.path.join(self.dataset_dir, sub_dataset_name, 'queries-{split}.json'.format(split=self.split))
qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
queries = json.load(open(queries_path))
rels = json.load(open(qrels_path))
new_queries = {}
for k in queries.keys():
if k in rels.keys():
new_queries[k] = queries[k]
queries_list.append(datasets.DatasetDict(new_queries))
return queries_list
queries_path = os.path.join(self.dataset_dir, sub_dataset_name, 'queries-{split}.json'.format(split=self.split))
qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
queries = json.load(open(queries_path))
print(qrels_path)
rels = json.load(open(qrels_path))
new_queries = {}
for k in queries.keys():
if k in rels.keys():
new_queries[k] = queries[k]
return datasets.DatasetDict(new_queries)
# queries_list = []
# for sub_dataset_name in self.sub_dataset_names:
# queries_path = os.path.join(self.dataset_dir, sub_dataset_name, 'queries-{split}.json'.format(split=self.split))
# qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
# queries = json.load(open(queries_path))
# rels = json.load(open(qrels_path))
# new_queries = {}
# for k in queries.keys():
# if k in rels.keys():
# new_queries[k] = queries[k]
# queries_list.append(datasets.DatasetDict(new_queries))
# return queries_list

View File

@@ -6,6 +6,8 @@ import pandas as pd
from typing import Dict, Optional, List, Union
from FlagEmbedding.abc.evaluation.evaluator import AbsEvaluator
from FlagEmbedding.abc.evaluation.data_loader import AbsDataLoader
from FlagEmbedding.abc.evaluation.searcher import AbsEmbedder, AbsReranker
class BEIREvaluator(AbsEvaluator):
def __init__(
@@ -33,50 +35,55 @@ class BEIREvaluator(AbsEvaluator):
):
dataset_name = self.data_loader.dataset_name
sub_dataset_names = self.data_loader.sub_dataset_names
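# sub_dataset_names is non-None only for datasets like cqadupstack that bundle multiple sub-corpora.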
split = self.data_loader.split
if isinstance(splits, str):
splits = [splits]
# Retrieval Stage
no_reranker_search_results_save_dir = os.path.join(
search_results_save_dir, str(retriever), "NoReranker"
search_results_save_dir, str(retriever), "NoReranker", dataset_name
)
os.makedirs(no_reranker_search_results_save_dir, exist_ok=True)
if corpus_embd_save_dir is not None: corpus_embd_save_dir = os.path.join(corpus_embd_save_dir, dataset_name)
flag = False
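# flag marks whether any expected search-result file is missing, in which case retrieval must run.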
if sub_dataset_names is None:
split_no_reranker_search_results_save_path = os.path.join(
no_reranker_search_results_save_dir, f"{dataset_name}.json"
no_reranker_search_results_save_dir, f"{split}.json"
)
if not os.path.exists(split_no_reranker_search_results_save_path) or self.overwrite:
flag = True
else:
for sub_dataset_name in sub_dataset_names:
split_no_reranker_search_results_save_path = os.path.join(
no_reranker_search_results_save_dir, f"{sub_dataset_name}-{split}.json"
)
if not os.path.exists(split_no_reranker_search_results_save_path) or self.overwrite:
flag = True
break
no_reranker_search_results_dict = {}
if flag:
corpus = self.data_loader.load_corpus()
if sub_dataset_names is None:
corpus = self.data_loader.load_corpus(sub_dataset_name=sub_dataset_names)
queries_dict = {
split: self.data_loader.load_queries(split=split)
for split in splits
}
queries_dict = {
split: self.data_loader.load_queries(sub_dataset_name=sub_dataset_names)
}
all_queries = {}
for _, split_queries in queries_dict.items():
all_queries.update(split_queries)
all_no_reranker_search_results = retriever(
corpus=corpus,
queries=all_queries,
corpus_embd_save_dir=corpus_embd_save_dir,
batch_size=retriever_batch_size,
query_max_length=retriever_query_max_length,
passage_max_length=retriever_passage_max_length,
**kwargs,
)
for split in splits:
split_queries = queries_dict[split]
no_reranker_search_results_dict[split] = {
qid: all_no_reranker_search_results[qid] for qid in split_queries
@@ -93,8 +100,47 @@
split=split,
dataset_dir=self.dataset_dir,
)
else:
for sub_dataset_name in sub_dataset_names:
corpus = self.data_loader.load_corpus(sub_dataset_name=sub_dataset_name)
queries_dict = {
split: self.data_loader.load_queries(sub_dataset_name=sub_dataset_name)
}
all_queries = {}
for _, split_queries in queries_dict.items():
all_queries.update(split_queries)
all_no_reranker_search_results = retriever(
corpus=corpus,
queries=all_queries,
corpus_embd_save_dir=None if corpus_embd_save_dir is None else os.path.join(corpus_embd_save_dir, sub_dataset_name),
batch_size=retriever_batch_size,
query_max_length=retriever_query_max_length,
passage_max_length=retriever_passage_max_length,
**kwargs,
)
split_queries = queries_dict[split]
no_reranker_search_results_dict[f"{sub_dataset_name}-{split}"] = {
qid: all_no_reranker_search_results[qid] for qid in split_queries
}
split_no_reranker_search_results_save_path = os.path.join(
no_reranker_search_results_save_dir, f"{sub_dataset_name}-{split}.json"
)
self.save_search_results(
model_name=str(retriever),
reranker_name="NoReranker",
search_results=no_reranker_search_results_dict[f"{sub_dataset_name}-{split}"],
output_path=split_no_reranker_search_results_save_path,
split=f"{sub_dataset_name}-{split}",
dataset_dir=self.dataset_dir,
)
else:
for split in splits:
if sub_dataset_names is None:
split_no_reranker_search_results_save_path = os.path.join(
no_reranker_search_results_save_dir, f"{split}.json"
)
@@ -107,47 +153,93 @@
split=split,
)
no_reranker_search_results_dict[split] = search_results
else:
for sub_dataset_name in sub_dataset_names:
split_no_reranker_search_results_save_path = os.path.join(
no_reranker_search_results_save_dir, f"{sub_dataset_name}-{split}.json"
)
data_info, search_results = self.load_search_results(split_no_reranker_search_results_save_path)
self.check_data_info(
data_info=data_info,
model_name=str(retriever),
reranker_name="NoReranker",
split=f"{sub_dataset_name}-{split}",
)
no_reranker_search_results_dict[f"{sub_dataset_name}-{split}"] = search_results
retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir)
self.output_eval_results_to_json(retriever_eval_results, os.path.join(no_reranker_search_results_save_dir, 'eval.json'))
# Reranking Stage
if reranker is not None:
reranker_search_results_save_dir = os.path.join(
search_results_save_dir, str(retriever), str(reranker)
search_results_save_dir, str(retriever), str(reranker), dataset_name
)
os.makedirs(reranker_search_results_save_dir, exist_ok=True)
corpus = self.data_loader.load_corpus()
if sub_dataset_names is None:
corpus = self.data_loader.load_corpus(sub_dataset_name=sub_dataset_names)
queries_dict = {
split: self.data_loader.load_queries(split=split)
for split in splits
}
queries_dict = {
split: self.data_loader.load_queries(sub_dataset_name=sub_dataset_names)
}
for split in splits:
rerank_search_results_save_path = os.path.join(
reranker_search_results_save_dir, f"{split}.json"
)
if os.path.exists(rerank_search_results_save_path) and not self.overwrite:
continue
pass
else:
rerank_search_results = reranker(
corpus=corpus,
queries=queries_dict[split],
search_results=no_reranker_search_results_dict[split],
batch_size=reranker_batch_size,
max_length=reranker_max_length,
**kwargs,
)
self.save_search_results(
model_name=str(retriever),
reranker_name=str(reranker),
search_results=rerank_search_results,
output_path=rerank_search_results_save_path,
split=split,
dataset_dir=self.dataset_dir,
)
else:
for sub_dataset_name in sub_dataset_names:
corpus = self.data_loader.load_corpus(sub_dataset_name=sub_dataset_name)
queries_dict = {
split: self.data_loader.load_queries(sub_dataset_name=sub_dataset_name)
}
rerank_search_results_save_path = os.path.join(
reranker_search_results_save_dir, f"{sub_dataset_name}-{split}.json"
)
if os.path.exists(rerank_search_results_save_path) and not self.overwrite:
continue
rerank_search_results = reranker(
corpus=corpus,
queries=queries_dict[split],
search_results=no_reranker_search_results_dict[f"{sub_dataset_name}-{split}"],
batch_size=reranker_batch_size,
max_length=reranker_max_length,
**kwargs,
)
self.save_search_results(
model_name=str(retriever),
reranker_name=str(reranker),
search_results=rerank_search_results,
output_path=rerank_search_results_save_path,
split=f"{sub_dataset_name}-{split}",
dataset_dir=self.dataset_dir,
)
self.save_search_results(
model_name=str(retriever),
reranker_name=str(reranker),
search_results=rerank_search_results,
output_path=rerank_search_results_save_path,
split=split,
dataset_dir=self.dataset_dir,
)
reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir)
self.output_eval_results_to_json(reranker_eval_results, os.path.join(reranker_search_results_save_dir, 'eval.json'))

View File

@@ -0,0 +1,44 @@
from transformers import HfArgumentParser
from FlagEmbedding import FlagAutoModel, FlagAutoReranker
from FlagEmbedding.abc.evaluation import AbsModelArgs, AbsEmbedder, AbsReranker, AbsEvaluator
from utils.arguments import MSMARCOEvalArgs
from utils.data_loader import MSMARCODataLoader
def get_models(model_args: AbsModelArgs):
retriever = FlagAutoModel.from_finetuned(
model_name_or_path=model_args.embedder_name_or_path,
normalize_embeddings=model_args.normalize_embeddings,
use_fp16=model_args.use_fp16,
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
query_instruction_format=model_args.query_instruction_format_for_retrieval,
devices=model_args.devices,
examples_for_task=model_args.examples_for_task,
examples_instruction_format=model_args.examples_instruction_format,
trust_remote_code=model_args.trust_remote_code,
cache_dir=model_args.cache_dir
)
reranker = None
if model_args.reranker_name_or_path is not None:
reranker = FlagAutoReranker.from_finetuned(
model_name_or_path=model_args.reranker_name_or_path,
peft_path=model_args.reranker_peft_path,
use_fp16=model_args.use_fp16,
use_bf16=model_args.use_bf16,
query_instruction_for_rerank=model_args.query_instruction_for_rerank,
query_instruction_format=model_args.query_instruction_format_for_rerank,
passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
passage_instruction_format=model_args.passage_instruction_format_for_rerank,
cache_dir=model_args.cache_dir,
trust_remote_code=model_args.trust_remote_code,
devices=model_args.devices,
normalize=model_args.normalize,
prompt=model_args.prompt,
cutoff_layers=model_args.cutoff_layers,
compress_layers=model_args.compress_layers,
compress_ratio=model_args.compress_ratio,
)
return retriever, reranker

View File

@@ -1,81 +1,329 @@
from transformers import HfArgumentParser
import multiprocessing
import os
import random
import sys
from FlagEmbedding import AutoFlagModel
import pandas as pd
import torch
import torch.nn.functional as F
import tqdm
import json
import numpy as np
import argparse
from utils.arguments import ModelArgs
from utils.models import SentenceTransformerEncoder, SentenceTransformerReranker
from utils.searcher import EmbeddingModelRetriever, CrossEncoderReranker
import mteb
from peft import PeftModel
from transformers import AutoModel, AutoTokenizer
from transformers.modeling_outputs import BaseModelOutput
from typing import List
from mteb import MTEB
from utils import logger, pool, move_to_cuda, get_detailed_instruct, get_task_def_by_task_name_and_type, create_batch_dict, tasks_desc, create_batch_query_dict
parser = argparse.ArgumentParser(description='evaluation for MTEB benchmark except its Retrieval category')
parser.add_argument('--task-types', nargs='+', default=[], help='task types to evaluate')
parser.add_argument('--output-dir', default='',
type=str, metavar='N', help='output directory')
parser.add_argument('--model-name-or-path', default='tmp-outputs/',
type=str, metavar='N', help='which model to use')
parser.add_argument('--peft-name-or-path', default=None, type=str)
parser.add_argument('--tokenizer-path', default=None)
parser.add_argument('--embedding-path', default=None)
parser.add_argument('--special-token', default=False, type=bool, help='whether to use special token')
parser.add_argument('--zero-shot', default=False, type=bool, help='whether to use zero shot icl')
parser.add_argument('--device', default=0, type=int)
parser.add_argument('--all-num', default=8, type=int)
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--examples-dir', default='/share/chaofan/code/embedder/evaluate_for_icl/examples', type=str)
parser.add_argument('--eight-special-token', default=False, type=bool)
parser.add_argument('--passage-prompt', default=False, type=bool)
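# Caveat: argparse's type=bool treats any non-empty string (even 'False') as True.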
args = parser.parse_args()
base_name: str = args.model_name_or_path.split('/')[-1]
if args.eight_special_token is True:
args.pool_type = 'last_eight'
else:
args.pool_type = 'last'
logger.info('Args: {}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4)))
assert args.pool_type in ['cls', 'avg', 'last', 'weightedavg', 'last_eight'], 'pool_type should be cls / avg / last / weightedavg / last_eight'
os.makedirs(args.output_dir, exist_ok=True)
ALL_NUM = args.all_num
print(args.special_token)
class DenseEncoder(torch.nn.Module):
def __init__(self, device, **kwargs):
super().__init__()
self.encoder = AutoModel.from_pretrained(args.model_name_or_path, use_cache=False)
if args.peft_name_or_path is not None:
peft_name_or_path = args.peft_name_or_path.split(',')
if args.embedding_path is not None:
self.encoder.set_input_embeddings(torch.load(args.embedding_path))
elif args.special_token and os.path.exists(os.path.join(peft_name_or_path[-1], 'embedding', 'emb.pth')):
self.encoder.set_input_embeddings(torch.load(os.path.join(peft_name_or_path[-1], 'embedding', 'emb.pth')))
for peft_path in peft_name_or_path:
self.encoder = PeftModel.from_pretrained(self.encoder, peft_path)
self.encoder = self.encoder.merge_and_unload()
if args.tokenizer_path is not None:
self.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
else:
self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
self.l2_normalize = True
self.prompt = None
self.prefix = ''
self.suffix = ''
self.gpu_count = torch.cuda.device_count()
self.encoder.half()
self.encoder.eval()
# self.encoder.cuda()
self.device = device
self.encoder = self.encoder.to(device)
self.eight_special_token = args.eight_special_token
if args.eight_special_token:
self.special_tokens = [self.tokenizer.eos_token_id] * 7
else:
self.special_tokens = []
self.passage_prompt = args.passage_prompt
self.batch_size = args.batch_size
# if self.gpu_count > 1:
# self.encoder = torch.nn.DataParallel(self.encoder)
@torch.no_grad()
def encode(self, sentences, **kwargs) -> np.ndarray:
""" Returns a list of embeddings for the given sentences.
Args:
sentences (`List[str]`): List of sentences to encode
batch_size (`int`): Batch size for the encoding
Returns:
`List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
"""
input_texts: List[str] = [self.prompt + s for s in sentences]
encoded_embeds = []
batch_size = self.batch_size * self.gpu_count
for start_idx in tqdm.tqdm(range(0, len(input_texts), batch_size), desc='encoding', mininterval=10):
batch_input_texts: List[str] = input_texts[start_idx: start_idx + batch_size]
batch_dict = create_batch_query_dict(self.tokenizer, self.prefix, self.suffix, batch_input_texts, special_tokens=self.special_tokens)
# if self.device == 0:
# print(self.tokenizer.decode(batch_dict['input_ids'][0]))
batch_dict = batch_dict.to(self.device)
# batch_dict = move_to_cuda(batch_dict)
with torch.cuda.amp.autocast():
outputs: BaseModelOutput = self.encoder(**batch_dict)
embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
if self.l2_normalize:
embeds = F.normalize(embeds, p=2, dim=-1)
encoded_embeds.append(embeds.cpu().numpy())
return np.concatenate(encoded_embeds, axis=0)
@torch.no_grad()
def encode_queries(self, sentences, **kwargs) -> np.ndarray:
""" Returns a list of embeddings for the given sentences.
Args:
sentences (`List[str]`): List of sentences to encode
batch_size (`int`): Batch size for the encoding
Returns:
`List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
"""
input_texts: List[str] = [self.prompt + s for s in sentences]
encoded_embeds = []
batch_size = self.batch_size * self.gpu_count
for start_idx in tqdm.tqdm(range(0, len(input_texts), batch_size), desc='encoding', mininterval=10):
batch_input_texts: List[str] = input_texts[start_idx: start_idx + batch_size]
batch_dict = create_batch_query_dict(self.tokenizer, self.prefix, self.suffix, batch_input_texts, special_tokens=self.special_tokens)
batch_dict = move_to_cuda(batch_dict)
with torch.cuda.amp.autocast():
outputs: BaseModelOutput = self.encoder(**batch_dict)
embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
if self.l2_normalize:
embeds = F.normalize(embeds, p=2, dim=-1)
encoded_embeds.append(embeds.cpu().numpy())
return np.concatenate(encoded_embeds, axis=0)
@torch.no_grad()
def encode_corpus(self, sentences, **kwargs) -> np.ndarray:
""" Returns a list of embeddings for the given sentences.
Args:
sentences (`List[str]`): List of sentences to encode
batch_size (`int`): Batch size for the encoding
Returns:
`List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
"""
if isinstance(sentences[0], str):
input_texts: List[str] = [s for s in sentences]
else:
input_texts: List[str] = [(s['text'] + ' ' + s['title']).strip() for s in sentences]
encoded_embeds = []
batch_size = self.batch_size * self.gpu_count
for start_idx in tqdm.tqdm(range(0, len(input_texts), batch_size), desc='encoding', mininterval=10):
batch_input_texts: List[str] = input_texts[start_idx: start_idx + batch_size]
batch_dict = create_batch_dict(self.tokenizer, batch_input_texts, special_tokens=self.special_tokens, passage_prompt=self.passage_prompt)
batch_dict = move_to_cuda(batch_dict)
with torch.cuda.amp.autocast():
outputs: BaseModelOutput = self.encoder(**batch_dict)
embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
if self.l2_normalize:
embeds = F.normalize(embeds, p=2, dim=-1)
encoded_embeds.append(embeds.cpu().numpy())
return np.concatenate(encoded_embeds, axis=0)
def set_prompt(self, prompt: str):
self.prompt = prompt
def set_prefix(self, prefix: str):
self.prefix = prefix
def set_suffix(self, suffix: str):
self.suffix = suffix
def get_models(model_args: ModelArgs):
embedding_model = AutoFlagModel.from_finetuned(
# NOTE: assuming ModelArgs exposes fields matching these keyword names
model_name_or_path=model_args.model_name_or_path,
normalize_embeddings=model_args.normalize_embeddings,
use_fp16=model_args.use_fp16,
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
query_instruction_format=model_args.query_instruction_format,
devices=model_args.devices,
examples_for_task=model_args.examples_for_task,
examples_instruction_format=model_args.examples_instruction_format,
trust_remote_code=model_args.trust_remote_code,
cache_dir=model_args.cache_dir
)
cross_encoder = None
if model_args.reranker is not None:
cross_encoder = SentenceTransformerReranker(
# NOTE: assuming ModelArgs exposes fields matching these keyword names
model_name_or_path=model_args.model_name_or_path,
peft_path=model_args.peft_path,
use_fp16=model_args.use_fp16,
use_bf16=model_args.use_bf16,
query_instruction_for_rerank=model_args.query_instruction_for_rerank,
query_instruction_format=model_args.query_instruction_format,
passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
passage_instruction_format=model_args.passage_instruction_format,
cache_dir=model_args.cache_dir,
trust_remote_code=model_args.trust_remote_code,
devices=model_args.devices
)
return embedding_model, cross_encoder
def main(device, all_pairs):
torch.cuda.set_device(device)
model = DenseEncoder(device)
def main():
parser = HfArgumentParser([ModelArgs, EvalArgs])
model_args, eval_args = parser.parse_args_into_dataclasses()
model_args: ModelArgs
eval_args: EvalArgs
os.environ['CUDA_VISIBLE_DEVICES'] = f'{device}'
embedding_model, cross_encoder = get_models(model_args)
evaluation = AIRBench(
benchmark_version=eval_args.benchmark_version,
task_types=eval_args.task_types,
domains=eval_args.domains,
languages=eval_args.languages,
splits=eval_args.splits,
cache_dir=eval_args.cache_dir,
)
retriever = EmbeddingModelRetriever(
embedding_model,
search_top_k=eval_args.search_top_k,
corpus_chunk_size=model_args.corpus_chunk_size,
)
if cross_encoder is not None:
reranker = CrossEncoderReranker(
cross_encoder,
rerank_top_k=eval_args.rerank_top_k,
)
ArxivClusteringP2P_FLAG = False
tmp = None
for i, p in enumerate(all_pairs):
if p[1] == 'ArxivClusteringP2P':
ArxivClusteringP2P_FLAG = True
tmp = p
break
if ArxivClusteringP2P_FLAG:
all_pairs.remove(tmp)
os.environ['CUDA_VISIBLE_DEVICES'] = f'{device}'
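# Shard the shuffled task list evenly across ALL_NUM workers; ArxivClusteringP2P, the heaviest task, is pulled out and handled by the last worker alone.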
if ArxivClusteringP2P_FLAG is False:
length = len(all_pairs)
start = device * length // ALL_NUM
if device == ALL_NUM - 1:
end = length
else:
end = (device + 1) * length // ALL_NUM
all_pairs = all_pairs[start: end]
else:
reranker = None
evaluation.run(
retriever,
reranker=reranker,
output_dir=eval_args.output_dir,
overwrite=eval_args.overwrite,
)
if device == ALL_NUM - 1:
all_pairs = [tmp]
else:
all_num = ALL_NUM - 1
length = len(all_pairs)
start = device * length // ALL_NUM
if device == all_num - 1:
end = length
else:
end = (device + 1) * length // all_num
all_pairs = all_pairs[start: end]
for (task_type, task_name) in all_pairs:
task_def: str = get_task_def_by_task_name_and_type(task_name=task_name, task_type=task_type)
prompt: str = get_detailed_instruct(task_def, args.special_token)
model.set_prompt(prompt=prompt)
eg_file_path = f'{args.examples_dir}/{task_name}.csv'
eg_pairs = []
if args.zero_shot:
eg_pairs = []
else:
df = pd.read_csv(eg_file_path)
for i in range(len(df)):
eg_pairs.append((get_detailed_instruct(
task_def,
args.special_token
) + df[df.keys()[0]][i], df[df.keys()[1]][i]))
if args.special_token:
if len(eg_pairs) > 0:
prefix = '\n\n'.join(['\n<response>'.join(eg_pairs[idx]) for idx in range(len(eg_pairs))]) + '\n\n'
else:
prefix = ''
suffix = '\n<response>'
else:
if len(eg_pairs) > 0:
prefix = '\n\n'.join(['\nResponse: '.join(eg_pairs[idx]) for idx in range(len(eg_pairs))]) + '\n\n'
else:
prefix = ''
suffix = '\nResponse:'
model.set_prefix(prefix)
model.set_suffix(suffix)
logger.info('Set prompt: {}'.format(prompt))
# disable l2 normalize for classification tasks, as it achieves slightly better results
if task_type == 'Classification':
logger.info('Set l2_normalize to False for classification task')
model.l2_normalize = False
else:
model.l2_normalize = True
logger.info('Set l2_normalize to {}'.format(model.l2_normalize))
sub_eval = MTEB(tasks=[task_name], task_langs=['en'])
logger.info('Running evaluation for task: {}, type: {}'.format(task_name, task_type))
# eval_splits = ["test"] if "test" in task_cls.description["eval_splits"] else task_cls.description["eval_splits"]
result_flag = False
model.batch_size = args.batch_size
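# Shrink the batch size by 4 and retry whenever evaluation fails (typically CUDA OOM).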
while result_flag is False:
try:
sub_eval.run(
model,
output_folder=args.output_dir
)
result_flag = True
except Exception as e:
model.batch_size -= 4
print(e)
if __name__ == "__main__":
main()
if __name__ == '__main__':
processes = []
multiprocessing.set_start_method('spawn')
random.seed(30)
args.task_types = [t for t in args.task_types if t.strip()]
all_pairs = []
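# args.task_types may list whole categories (e.g. 'Clustering') or individual task names; gather both, then deduplicate.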
for task_type in args.task_types:
if task_type in tasks_desc.keys():
for task_name in tasks_desc[task_type]:
all_pairs.append((task_type, task_name))
for task_type in tasks_desc.keys():
for v in tasks_desc[task_type]:
if v in args.task_types:
all_pairs.append((task_type, v))
all_pairs = list(set(all_pairs))
random.shuffle(all_pairs)
for i in range(ALL_NUM):
# i = 7
process = multiprocessing.Process(target=main, args=(i, all_pairs))
processes.append(process)
process.start()
for process in processes:
process.join()

View File

@@ -0,0 +1,365 @@
import sys
import torch
import logging
from torch import Tensor
from transformers import PreTrainedTokenizerFast, BatchEncoding
from typing import Mapping, Dict, List
def _setup_logger():
log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
logger = logging.getLogger()
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setFormatter(log_format)
logger.handlers = [console_handler]
return logger
logger = _setup_logger()
def move_to_cuda(sample):
if len(sample) == 0:
return {}
def _move_to_cuda(maybe_tensor):
if torch.is_tensor(maybe_tensor):
return maybe_tensor.cuda(non_blocking=True)
elif isinstance(maybe_tensor, dict):
return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
elif isinstance(maybe_tensor, list):
return [_move_to_cuda(x) for x in maybe_tensor]
elif isinstance(maybe_tensor, tuple):
return tuple([_move_to_cuda(x) for x in maybe_tensor])
elif isinstance(maybe_tensor, Mapping):
return type(maybe_tensor)({k: _move_to_cuda(v) for k, v in maybe_tensor.items()})
else:
return maybe_tensor
return _move_to_cuda(sample)
def pool(last_hidden_states: Tensor,
attention_mask: Tensor,
pool_type: str) -> Tensor:
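# Zero out hidden states at padded positions, then reduce over the sequence dimension according to pool_type.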
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
if pool_type == "avg":
emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
elif pool_type == "weightedavg": # position-weighted mean pooling from SGPT (https://arxiv.org/abs/2202.08904)
attention_mask *= attention_mask.cumsum(dim=1) # [0,1,1,1,0,0] -> [0,1,2,3,0,0]
s = torch.sum(last_hidden * attention_mask.unsqueeze(-1).float(), dim=1)
d = attention_mask.sum(dim=1, keepdim=True).float()
emb = s / d
elif pool_type == "cls":
emb = last_hidden[:, 0]
elif pool_type == "last":
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
if left_padding:
emb = last_hidden[:, -1]
else:
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden.shape[0]
emb = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths]
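# 'last_eight' mean-pools the hidden states of the final eight tokens and requires left padding.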
elif pool_type == "last_eight":
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
if left_padding:
emb = torch.mean(last_hidden[:, -8:, :], dim=-2)
else:
sys.exit()
else:
raise ValueError(f"pool_type {pool_type} not supported")
return emb
def create_batch_dict(tokenizer: PreTrainedTokenizerFast, input_texts: List[str], max_length: int = 512, special_tokens: list = [], passage_prompt: bool = False) -> BatchEncoding:
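# Tokenize once without special tokens to truncate the passage, then re-encode with the suffix and EOS markers appended so they are never cut off.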
passage_suffix = ''
passage_suffix_len = 0
if passage_prompt:
passage_suffix = '\nSummarize the above passage: '
passage_suffix_len = len(tokenizer(passage_suffix, add_special_tokens=False)['input_ids'])
inputs = tokenizer(
input_texts,
max_length=max_length - len(tokenizer('<s>', add_special_tokens=False)['input_ids']) - len(
tokenizer('</s>', add_special_tokens=False)['input_ids']) - len(special_tokens) - passage_suffix_len,
return_token_type_ids=False,
truncation=True,
return_tensors=None,
add_special_tokens=False
)
input_texts = tokenizer.batch_decode(inputs['input_ids'])
if len(special_tokens) > 0:
for i in range(len(input_texts)):
input_texts[i] = input_texts[i] + passage_suffix + tokenizer.eos_token * 7
else:
for i in range(len(input_texts)):
input_texts[i] = input_texts[i] + passage_suffix
return tokenizer(
input_texts,
max_length=max_length,
padding=True,
pad_to_multiple_of=8,
return_token_type_ids=False,
truncation=True,
return_tensors='pt'
)
def create_batch_query_dict(tokenizer: PreTrainedTokenizerFast, prefix: str, suffix: str, input_texts: List[str], max_length: int = 512, special_tokens: list = []) -> BatchEncoding:
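# Reserve room for the instruction prefix, the response suffix, and any special tokens before truncating the query text.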
inputs = tokenizer(
input_texts,
max_length=max_length - len(tokenizer('<s>', add_special_tokens=False)['input_ids']) - len(tokenizer('\n<response></s>', add_special_tokens=False)['input_ids']) - len(special_tokens),
return_token_type_ids=False,
truncation=True,
return_tensors=None,
add_special_tokens=False
)
prefix_ids = tokenizer(prefix, add_special_tokens=False)['input_ids']
suffix_ids = tokenizer(suffix, add_special_tokens=False)['input_ids']
new_max_length = (len(prefix_ids) + len(suffix_ids) + max_length) // 8 * 8 + 8
input_texts = tokenizer.batch_decode(inputs['input_ids'])
for i in range(len(input_texts)):
if len(special_tokens) > 0:
input_texts[i] = prefix + input_texts[i] + suffix + tokenizer.eos_token * 7
else:
input_texts[i] = prefix + input_texts[i] + suffix
return tokenizer(
input_texts,
max_length=new_max_length,
padding=True,
pad_to_multiple_of=8,
return_token_type_ids=False,
truncation=True,
return_tensors='pt'
)
def get_task_def_by_task_name_and_type(task_name: str, task_type: str) -> str:
if task_type in ['STS']:
return "Retrieve semantically similar text."
if task_type in ['Summarization']:
return "Given a news summary, retrieve other semantically similar summaries."
if task_type in ['BitextMining']:
return "Retrieve parallel sentences."
if task_type in ['Classification']:
task_name_to_instruct: Dict[str, str] = {
'AmazonCounterfactualClassification': 'Classify a given Amazon customer review text as either counterfactual or not-counterfactual.',
'AmazonPolarityClassification': 'Classify Amazon reviews into positive or negative sentiment.',
'AmazonReviewsClassification': 'Classify the given Amazon review into its appropriate rating category.',
'Banking77Classification': 'Given a online banking query, find the corresponding intents.',
'EmotionClassification': 'Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise.',
'ImdbClassification': 'Classify the sentiment expressed in the given movie review text from the IMDB dataset.',
'MassiveIntentClassification': 'Given a user utterance as query, find the user intents.',
'MassiveScenarioClassification': 'Given a user utterance as query, find the user scenarios.',
'MTOPDomainClassification': 'Classify the intent domain of the given utterance in task-oriented conversation.',
'MTOPIntentClassification': 'Classify the intent of the given utterance in task-oriented conversation.',
'ToxicConversationsClassification': 'Classify the given comments as either toxic or not toxic.',
'TweetSentimentExtractionClassification': 'Classify the sentiment of a given tweet as either positive, negative, or neutral.',
# C-MTEB eval instructions
'TNews': 'Classify the fine-grained category of the given news title.',
'IFlyTek': 'Given an App description text, find the appropriate fine-grained category.',
'MultilingualSentiment': 'Classify sentiment of the customer review into positive, neutral, or negative.',
'JDReview': 'Classify the customer review for iPhone on e-commerce platform into positive or negative.',
'OnlineShopping': 'Classify the customer review for online shopping into positive or negative.',
'Waimai': 'Classify the customer review from a food takeaway platform into positive or negative.',
}
return task_name_to_instruct[task_name]
if task_type in ['Clustering']:
task_name_to_instruct: Dict[str, str] = {
'ArxivClusteringP2P': 'Identify the main and secondary category of Arxiv papers based on the titles and abstracts.',
'ArxivClusteringS2S': 'Identify the main and secondary category of Arxiv papers based on the titles.',
'BiorxivClusteringP2P': 'Identify the main category of Biorxiv papers based on the titles and abstracts.',
'BiorxivClusteringS2S': 'Identify the main category of Biorxiv papers based on the titles.',
'MedrxivClusteringP2P': 'Identify the main category of Medrxiv papers based on the titles and abstracts.',
'MedrxivClusteringS2S': 'Identify the main category of Medrxiv papers based on the titles.',
'RedditClustering': 'Identify the topic or theme of Reddit posts based on the titles.',
'RedditClusteringP2P': 'Identify the topic or theme of Reddit posts based on the titles and posts.',
'StackExchangeClustering': 'Identify the topic or theme of StackExchange posts based on the titles.',
'StackExchangeClusteringP2P': 'Identify the topic or theme of StackExchange posts based on the given paragraphs.',
'TwentyNewsgroupsClustering': 'Identify the topic or theme of the given news articles.',
# C-MTEB eval instructions
'CLSClusteringS2S': 'Identify the main category of scholar papers based on the titles.',
'CLSClusteringP2P': 'Identify the main category of scholar papers based on the titles and abstracts.',
'ThuNewsClusteringS2S': 'Identify the topic or theme of the given news articles based on the titles.',
'ThuNewsClusteringP2P': 'Identify the topic or theme of the given news articles based on the titles and contents.',
}
return task_name_to_instruct[task_name]
if task_type in ['Reranking', 'PairClassification']:
task_name_to_instruct: Dict[str, str] = {
'AskUbuntuDupQuestions': 'Retrieve duplicate questions from AskUbuntu forum.',
'MindSmallReranking': 'Retrieve relevant news articles based on user browsing history.',
'SciDocsRR': 'Given a title of a scientific paper, retrieve the titles of other relevant papers.',
'StackOverflowDupQuestions': 'Retrieve duplicate questions from StackOverflow forum.',
'SprintDuplicateQuestions': 'Retrieve duplicate questions from Sprint forum.',
'TwitterSemEval2015': 'Retrieve tweets that are semantically similar to the given tweet.',
'TwitterURLCorpus': 'Retrieve tweets that are semantically similar to the given tweet.',
# C-MTEB eval instructions
'T2Reranking': 'Given a Chinese search query, retrieve web passages that answer the question.',
'MMarcoReranking': 'Given a Chinese search query, retrieve web passages that answer the question.',
'CMedQAv1': 'Given a Chinese community medical question, retrieve replies that best answer the question.',
'CMedQAv2': 'Given a Chinese community medical question, retrieve replies that best answer the question.',
'Ocnli': 'Retrieve semantically similar text.',
'Cmnli': 'Retrieve semantically similar text.',
}
return task_name_to_instruct[task_name]
if task_type in ['Retrieval']:
if task_name.lower().startswith('cqadupstack'):
return 'Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question.'
task_name_to_instruct: Dict[str, str] = {
'ArguAna': 'Given a claim, find documents that refute the claim.',
'ClimateFEVER': 'Given a claim about climate change, retrieve documents that support or refute the claim.',
'DBPedia': 'Given a query, retrieve relevant entity descriptions from DBPedia.',
'FEVER': 'Given a claim, retrieve documents that support or refute the claim.',
'FiQA2018': 'Given a financial question, retrieve user replies that best answer the question.',
'HotpotQA': 'Given a multi-hop question, retrieve documents that can help answer the question.',
'MSMARCO': 'Given a web search query, retrieve relevant passages that answer the query.',
'NFCorpus': 'Given a question, retrieve relevant documents that best answer the question.',
'NQ': 'Given a question, retrieve Wikipedia passages that answer the question.',
'QuoraRetrieval': 'Given a question, retrieve questions that are semantically equivalent to the given question.',
'SCIDOCS': 'Given a scientific paper title, retrieve paper abstracts that are cited by the given paper.',
'SciFact': 'Given a scientific claim, retrieve documents that support or refute the claim.',
'Touche2020': 'Given a question, retrieve detailed and persuasive arguments that answer the question.',
'TRECCOVID': 'Given a query on COVID-19, retrieve documents that answer the query.',
# C-MTEB eval instructions
'T2Retrieval': 'Given a Chinese search query, retrieve web passages that answer the question.',
'MMarcoRetrieval': 'Given a web search query, retrieve relevant passages that answer the query.',
'DuRetrieval': 'Given a Chinese search query, retrieve web passages that answer the question.',
'CovidRetrieval': 'Given a question on COVID-19, retrieve news articles that answer the question.',
'CmedqaRetrieval': 'Given a Chinese community medical question, retrieve replies that best answer the question.',
'EcomRetrieval': 'Given a user query from an e-commerce website, retrieve description sentences of relevant products.',
'MedicalRetrieval': 'Given a medical question, retrieve user replies that best answer the question.',
'VideoRetrieval': 'Given a video search query, retrieve the titles of relevant videos.',
}
# add lower case keys to match some beir names
task_name_to_instruct.update({k.lower(): v for k, v in task_name_to_instruct.items()})
# other cases where lower case match still doesn't work
task_name_to_instruct['trec-covid'] = task_name_to_instruct['TRECCOVID']
task_name_to_instruct['climate-fever'] = task_name_to_instruct['ClimateFEVER']
task_name_to_instruct['dbpedia-entity'] = task_name_to_instruct['DBPedia']
task_name_to_instruct['webis-touche2020'] = task_name_to_instruct['Touche2020']
task_name_to_instruct['fiqa'] = task_name_to_instruct['FiQA2018']
task_name_to_instruct['quora'] = task_name_to_instruct['QuoraRetrieval']
# for miracl evaluation
task_name_to_instruct['miracl'] = 'Given a question, retrieve Wikipedia passages that answer the question.'
return task_name_to_instruct[task_name]
raise ValueError(f"No instruction config for task {task_name} with type {task_type}")
def get_detailed_instruct(task_description: str, special_token: bool) -> str:
if not task_description:
return ''
if special_token:
return f'<instruct>{task_description}\n<query>'
else:
return f'Instruct: {task_description}\nQuery: '
tasks_desc = {
'Retrieval': [
'ArguAna',
'ClimateFEVER',
'DBPedia',
'FEVER',
'FiQA2018',
'HotpotQA',
'MSMARCO',
'NFCorpus',
'NQ',
'QuoraRetrieval',
'SCIDOCS',
'SciFact',
'Touche2020',
'TRECCOVID',
'CQADupstackAndroidRetrieval',
'CQADupstackEnglishRetrieval',
'CQADupstackGamingRetrieval',
'CQADupstackGisRetrieval',
'CQADupstackMathematicaRetrieval',
'CQADupstackPhysicsRetrieval',
'CQADupstackProgrammersRetrieval',
'CQADupstackStatsRetrieval',
'CQADupstackTexRetrieval',
'CQADupstackUnixRetrieval',
'CQADupstackWebmastersRetrieval',
'CQADupstackWordpressRetrieval'
],
'Classification': [
# 12
'AmazonCounterfactualClassification',
'AmazonPolarityClassification',
'AmazonReviewsClassification',
'Banking77Classification',
'EmotionClassification',
'ImdbClassification',
'MassiveIntentClassification',
'MassiveScenarioClassification',
'MTOPDomainClassification',
'MTOPIntentClassification',
'ToxicConversationsClassification',
'TweetSentimentExtractionClassification',
],
'Clustering': [
# 11
'ArxivClusteringP2P',
'ArxivClusteringS2S',
'BiorxivClusteringP2P',
'BiorxivClusteringS2S',
'MedrxivClusteringP2P',
'MedrxivClusteringS2S',
'RedditClustering',
'RedditClusteringP2P',
'StackExchangeClustering',
'StackExchangeClusteringP2P',
'TwentyNewsgroupsClustering',
],
'PairClassification': [
# 3
'SprintDuplicateQuestions',
'TwitterSemEval2015',
'TwitterURLCorpus',
],
'Reranking': [
# 4
'AskUbuntuDupQuestions',
'MindSmallReranking',
'SciDocsRR',
'StackOverflowDupQuestions',
],
'STS': [
# 10
'BIOSSES',
'SICK-R',
'STS12',
'STS13',
'STS14',
'STS15',
'STS16',
'STS17',
'STS22',
'STSBenchmark',
],
'Summarization': [
# 1
'SummEval',
]
}

View File

Can't render this file because it contains an unexpected character in line 4 and column 86.

View File

@@ -0,0 +1,216 @@
import sys
import torch
from typing import Dict
def get_task_def_by_task_name_and_type(task_name: str, task_type: str) -> str:
if task_type in ['STS']:
return "Retrieve semantically similar text."
if task_type in ['Summarization']:
return "Given a news summary, retrieve other semantically similar summaries."
if task_type in ['BitextMining']:
return "Retrieve parallel sentences."
if task_type in ['Classification']:
task_name_to_instruct: Dict[str, str] = {
'AmazonCounterfactualClassification': 'Classify a given Amazon customer review text as either counterfactual or not-counterfactual.',
'AmazonPolarityClassification': 'Classify Amazon reviews into positive or negative sentiment.',
'AmazonReviewsClassification': 'Classify the given Amazon review into its appropriate rating category.',
'Banking77Classification': 'Given a online banking query, find the corresponding intents.',
'EmotionClassification': 'Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise.',
'ImdbClassification': 'Classify the sentiment expressed in the given movie review text from the IMDB dataset.',
'MassiveIntentClassification': 'Given a user utterance as query, find the user intents.',
'MassiveScenarioClassification': 'Given a user utterance as query, find the user scenarios.',
'MTOPDomainClassification': 'Classify the intent domain of the given utterance in task-oriented conversation.',
'MTOPIntentClassification': 'Classify the intent of the given utterance in task-oriented conversation.',
'ToxicConversationsClassification': 'Classify the given comments as either toxic or not toxic.',
'TweetSentimentExtractionClassification': 'Classify the sentiment of a given tweet as either positive, negative, or neutral.',
# C-MTEB eval instructions
'TNews': 'Classify the fine-grained category of the given news title.',
'IFlyTek': 'Given an App description text, find the appropriate fine-grained category.',
'MultilingualSentiment': 'Classify sentiment of the customer review into positive, neutral, or negative.',
'JDReview': 'Classify the customer review for iPhone on e-commerce platform into positive or negative.',
'OnlineShopping': 'Classify the customer review for online shopping into positive or negative.',
'Waimai': 'Classify the customer review from a food takeaway platform into positive or negative.',
}
return task_name_to_instruct[task_name]
if task_type in ['Clustering']:
task_name_to_instruct: Dict[str, str] = {
'ArxivClusteringP2P': 'Identify the main and secondary category of Arxiv papers based on the titles and abstracts.',
'ArxivClusteringS2S': 'Identify the main and secondary category of Arxiv papers based on the titles.',
'BiorxivClusteringP2P': 'Identify the main category of Biorxiv papers based on the titles and abstracts.',
'BiorxivClusteringS2S': 'Identify the main category of Biorxiv papers based on the titles.',
'MedrxivClusteringP2P': 'Identify the main category of Medrxiv papers based on the titles and abstracts.',
'MedrxivClusteringS2S': 'Identify the main category of Medrxiv papers based on the titles.',
'RedditClustering': 'Identify the topic or theme of Reddit posts based on the titles.',
'RedditClusteringP2P': 'Identify the topic or theme of Reddit posts based on the titles and posts.',
'StackExchangeClustering': 'Identify the topic or theme of StackExchange posts based on the titles.',
'StackExchangeClusteringP2P': 'Identify the topic or theme of StackExchange posts based on the given paragraphs.',
'TwentyNewsgroupsClustering': 'Identify the topic or theme of the given news articles.',
# C-MTEB eval instructions
'CLSClusteringS2S': 'Identify the main category of scholar papers based on the titles.',
'CLSClusteringP2P': 'Identify the main category of scholar papers based on the titles and abstracts.',
'ThuNewsClusteringS2S': 'Identify the topic or theme of the given news articles based on the titles.',
'ThuNewsClusteringP2P': 'Identify the topic or theme of the given news articles based on the titles and contents.',
}
return task_name_to_instruct[task_name]
if task_type in ['Reranking', 'PairClassification']:
task_name_to_instruct: Dict[str, str] = {
'AskUbuntuDupQuestions': 'Retrieve duplicate questions from AskUbuntu forum.',
'MindSmallReranking': 'Retrieve relevant news articles based on user browsing history.',
'SciDocsRR': 'Given a title of a scientific paper, retrieve the titles of other relevant papers.',
'StackOverflowDupQuestions': 'Retrieve duplicate questions from StackOverflow forum.',
'SprintDuplicateQuestions': 'Retrieve duplicate questions from Sprint forum.',
'TwitterSemEval2015': 'Retrieve tweets that are semantically similar to the given tweet.',
'TwitterURLCorpus': 'Retrieve tweets that are semantically similar to the given tweet.',
# C-MTEB eval instructions
'T2Reranking': 'Given a Chinese search query, retrieve web passages that answer the question.',
'MMarcoReranking': 'Given a Chinese search query, retrieve web passages that answer the question.',
'CMedQAv1': 'Given a Chinese community medical question, retrieve replies that best answer the question.',
'CMedQAv2': 'Given a Chinese community medical question, retrieve replies that best answer the question.',
'Ocnli': 'Retrieve semantically similar text.',
'Cmnli': 'Retrieve semantically similar text.',
}
return task_name_to_instruct[task_name]
if task_type in ['Retrieval']:
if task_name.lower().startswith('cqadupstack'):
return 'Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question.'
task_name_to_instruct: Dict[str, str] = {
'ArguAna': 'Given a claim, find documents that refute the claim.',
'ClimateFEVER': 'Given a claim about climate change, retrieve documents that support or refute the claim.',
'DBPedia': 'Given a query, retrieve relevant entity descriptions from DBPedia.',
'FEVER': 'Given a claim, retrieve documents that support or refute the claim.',
'FiQA2018': 'Given a financial question, retrieve user replies that best answer the question.',
'HotpotQA': 'Given a multi-hop question, retrieve documents that can help answer the question.',
'MSMARCO': 'Given a web search query, retrieve relevant passages that answer the query.',
'NFCorpus': 'Given a question, retrieve relevant documents that best answer the question.',
'NQ': 'Given a question, retrieve Wikipedia passages that answer the question.',
'QuoraRetrieval': 'Given a question, retrieve questions that are semantically equivalent to the given question.',
'SCIDOCS': 'Given a scientific paper title, retrieve paper abstracts that are cited by the given paper.',
'SciFact': 'Given a scientific claim, retrieve documents that support or refute the claim.',
'Touche2020': 'Given a question, retrieve detailed and persuasive arguments that answer the question.',
'TRECCOVID': 'Given a query on COVID-19, retrieve documents that answer the query.',
# C-MTEB eval instructions
'T2Retrieval': 'Given a Chinese search query, retrieve web passages that answer the question.',
'MMarcoRetrieval': 'Given a web search query, retrieve relevant passages that answer the query.',
'DuRetrieval': 'Given a Chinese search query, retrieve web passages that answer the question.',
'CovidRetrieval': 'Given a question on COVID-19, retrieve news articles that answer the question.',
'CmedqaRetrieval': 'Given a Chinese community medical question, retrieve replies that best answer the question.',
'EcomRetrieval': 'Given a user query from an e-commerce website, retrieve description sentences of relevant products.',
'MedicalRetrieval': 'Given a medical question, retrieve user replies that best answer the question.',
'VideoRetrieval': 'Given a video search query, retrieve the titles of relevant videos.',
}
# add lower case keys to match some beir names
task_name_to_instruct.update({k.lower(): v for k, v in task_name_to_instruct.items()})
# other cases where lower case match still doesn't work
task_name_to_instruct['trec-covid'] = task_name_to_instruct['TRECCOVID']
task_name_to_instruct['climate-fever'] = task_name_to_instruct['ClimateFEVER']
task_name_to_instruct['dbpedia-entity'] = task_name_to_instruct['DBPedia']
task_name_to_instruct['webis-touche2020'] = task_name_to_instruct['Touche2020']
task_name_to_instruct['fiqa'] = task_name_to_instruct['FiQA2018']
task_name_to_instruct['quora'] = task_name_to_instruct['QuoraRetrieval']
# for miracl evaluation
task_name_to_instruct['miracl'] = 'Given a question, retrieve Wikipedia passages that answer the question.'
return task_name_to_instruct[task_name]
raise ValueError(f"No instruction config for task {task_name} with type {task_type}")
tasks_desc = {
'Retrieval': [
'ArguAna',
'ClimateFEVER',
'DBPedia',
'FEVER',
'FiQA2018',
'HotpotQA',
'MSMARCO',
'NFCorpus',
'NQ',
'QuoraRetrieval',
'SCIDOCS',
'SciFact',
'Touche2020',
'TRECCOVID',
'CQADupstackAndroidRetrieval',
'CQADupstackEnglishRetrieval',
'CQADupstackGamingRetrieval',
'CQADupstackGisRetrieval',
'CQADupstackMathematicaRetrieval',
'CQADupstackPhysicsRetrieval',
'CQADupstackProgrammersRetrieval',
'CQADupstackStatsRetrieval',
'CQADupstackTexRetrieval',
'CQADupstackUnixRetrieval',
'CQADupstackWebmastersRetrieval',
'CQADupstackWordpressRetrieval'
],
'Classification': [
# 12
'AmazonCounterfactualClassification',
'AmazonPolarityClassification',
'AmazonReviewsClassification',
'Banking77Classification',
'EmotionClassification',
'ImdbClassification',
'MassiveIntentClassification',
'MassiveScenarioClassification',
'MTOPDomainClassification',
'MTOPIntentClassification',
'ToxicConversationsClassification',
'TweetSentimentExtractionClassification',
],
'Clustering': [
# 11
'ArxivClusteringP2P',
'ArxivClusteringS2S',
'BiorxivClusteringP2P',
'BiorxivClusteringS2S',
'MedrxivClusteringP2P',
'MedrxivClusteringS2S',
'RedditClustering',
'RedditClusteringP2P',
'StackExchangeClustering',
'StackExchangeClusteringP2P',
'TwentyNewsgroupsClustering',
],
'PairClassification': [
# 3
'SprintDuplicateQuestions',
'TwitterSemEval2015',
'TwitterURLCorpus',
],
'Reranking': [
# 4
'AskUbuntuDupQuestions',
'MindSmallReranking',
'SciDocsRR',
'StackOverflowDupQuestions',
],
'STS': [
# 10
'BIOSSES',
'SICK-R',
'STS12',
'STS13',
'STS14',
'STS15',
'STS16',
'STS17',
'STS22',
'STSBenchmark',
],
'Summarization': [
# 1
'SummEval',
]
}

View File

@@ -20,5 +20,8 @@ setup(
'accelerate>=0.20.1',
'sentence_transformers',
'peft',
'beir',
'deepspeed',
'flash-attn'
],
)