From 00a42ccd4f4efe9960ea3e8b73a49bf96bf57b49 Mon Sep 17 00:00:00 2001
From: cfli <545999961@qq.com>
Date: Thu, 24 Oct 2024 15:48:21 +0800
Subject: [PATCH] eval beir

---
 FlagEmbedding/evaluation/beir/__main__.py | 96 +++++
 FlagEmbedding/evaluation/beir/run.sh | 12 +
 .../evaluation/beir/utils/arguments.py | 1 +
 .../evaluation/beir/utils/data_loader.py | 117 ++++--
 .../evaluation/beir/utils/evaluator.py | 188 ++++---
 FlagEmbedding/evaluation/mteb/__main__.py | 44 ++
 FlagEmbedding/evaluation/mteb/evaluate.py | 390 ++++++++++++++----
 FlagEmbedding/evaluation/mteb/utils.py | 365 ++++++++++++++++
 .../air-bench/long-doc/arxiv-gemini.jsonl | 0
 .../air-bench/long-doc/arxiv-gpt3.jsonl | 0
 .../air-bench/long-doc/arxiv-llama2.jsonl | 0
 .../air-bench/long-doc/arxiv-llm-survey.jsonl | 0
 ...rief-history-of-time_stephen-hawking.jsonl | 0
 .../book-origin-of-species_darwin.jsonl | 0
 .../healthcare-pubmed_100k-200k_1.jsonl | 0
 .../healthcare-pubmed_100k-200k_2.jsonl | 0
 .../healthcare-pubmed_100k-200k_3.jsonl | 0
 .../healthcare-pubmed_30k-40k_10-merged.jsonl | 0
 .../healthcare-pubmed_40k-50k_5-merged.jsonl | 0
 .../long-doc/law-lex_files_300k-400k.jsonl | 0
 .../long-doc/law-lex_files_400k-500k.jsonl | 0
 .../long-doc/law-lex_files_500k-600k.jsonl | 0
 .../long-doc/law-lex_files_600k-700k.jsonl | 0
 .../utils}/examples/air-bench/qa/arxiv.jsonl | 0
 .../examples/air-bench/qa/finance.jsonl | 0
 .../examples/air-bench/qa/healthcare.jsonl | 0
 .../utils}/examples/air-bench/qa/law.jsonl | 0
 .../examples/air-bench/qa/msmarco.jsonl | 0
 .../utils}/examples/air-bench/qa/news.jsonl | 0
 .../utils}/examples/air-bench/qa/web.jsonl | 0
 .../utils}/examples/air-bench/qa/wiki.jsonl | 0
 .../AmazonCounterfactualClassification.csv | 0
 .../mteb/AmazonPolarityClassification.csv | 0
 .../mteb/AmazonReviewsClassification.csv | 0
 .../mteb/utils}/examples/mteb/ArguAna.csv | 0
 .../examples/mteb/ArxivClusteringP2P.csv | 0
 .../examples/mteb/ArxivClusteringS2S.csv | 0
 .../examples/mteb/AskUbuntuDupQuestions.csv | 0
 .../mteb/utils}/examples/mteb/BIOSSES.csv | 0
 .../examples/mteb/Banking77Classification.csv | 0
 .../examples/mteb/BiorxivClusteringP2P.csv | 0
 .../examples/mteb/BiorxivClusteringS2S.csv | 0
 .../mteb/utils}/examples/mteb/CQADupstack.csv | 0
 .../examples/mteb/CQADupstackRetrieval.csv | 0
 .../utils}/examples/mteb/ClimateFEVER.csv | 0
 .../mteb/utils}/examples/mteb/DBPedia.csv | 0
 .../examples/mteb/EmotionClassification.csv | 0
 .../mteb/utils}/examples/mteb/FEVER.csv | 0
 .../mteb/utils}/examples/mteb/FiQA2018.csv | 0
 .../mteb/utils}/examples/mteb/HotpotQA.csv | 0
 .../examples/mteb/ImdbClassification.csv | 0
 .../mteb/utils}/examples/mteb/MSMARCO.csv | 0
 .../mteb/MTOPDomainClassification.csv | 0
 .../mteb/MTOPIntentClassification.csv | 0
 .../mteb/MassiveIntentClassification.csv | 0
 .../mteb/MassiveScenarioClassification.csv | 0
 .../examples/mteb/MedrxivClusteringP2P.csv | 0
 .../examples/mteb/MedrxivClusteringS2S.csv | 0
 .../examples/mteb/MindSmallReranking.csv | 0
 .../mteb/utils}/examples/mteb/NFCorpus.csv | 0
 .../mteb/utils}/examples/mteb/NQ.csv | 0
 .../utils}/examples/mteb/QuoraRetrieval.csv | 0
 .../utils}/examples/mteb/RedditClustering.csv | 0
 .../examples/mteb/RedditClusteringP2P.csv | 0
 .../mteb/utils}/examples/mteb/SCIDOCS.csv | 0
 .../mteb/utils}/examples/mteb/SICK-R.csv | 0
 .../mteb/utils}/examples/mteb/STS12.csv | 0
 .../mteb/utils}/examples/mteb/STS13.csv | 0
 .../mteb/utils}/examples/mteb/STS14.csv | 0
 .../mteb/utils}/examples/mteb/STS15.csv | 0
 .../mteb/utils}/examples/mteb/STS16.csv | 0
 .../mteb/utils}/examples/mteb/STS17.csv | 0
 .../mteb/utils}/examples/mteb/STS22.csv | 0
 .../utils}/examples/mteb/STSBenchmark.csv | 0
 .../mteb/utils}/examples/mteb/SciDocsRR.csv | 0
 .../mteb/utils}/examples/mteb/SciFact.csv | 0
 .../mteb/SprintDuplicateQuestions.csv | 0
 .../examples/mteb/StackExchangeClustering.csv | 0
 .../mteb/StackExchangeClusteringP2P.csv | 0
 .../mteb/StackOverflowDupQuestions.csv | 0
 .../mteb/utils}/examples/mteb/SummEval.csv | 0
 .../mteb/utils}/examples/mteb/TRECCOVID.csv | 0
 .../mteb/utils}/examples/mteb/Touche2020.csv | 0
 .../mteb/ToxicConversationsClassification.csv | 0
 ...TweetSentimentExtractionClassification.csv | 0
 .../mteb/TwentyNewsgroupsClustering.csv | 0
 .../examples/mteb/TwitterSemEval2015.csv | 0
 .../utils}/examples/mteb/TwitterURLCorpus.csv | 0
 .../evaluation/mteb/utils/prompts.py | 216 ++++++++++
 setup.py | 3 +
 90 files changed, 1272 insertions(+), 160 deletions(-)

 create mode 100644 FlagEmbedding/evaluation/beir/run.sh
 create mode 100644 FlagEmbedding/evaluation/mteb/utils.py
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/long-doc/arxiv-gemini.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/long-doc/arxiv-gpt3.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/long-doc/arxiv-llama2.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/long-doc/arxiv-llm-survey.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/long-doc/book-a-brief-history-of-time_stephen-hawking.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/long-doc/book-origin-of-species_darwin.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/long-doc/healthcare-pubmed_100k-200k_1.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/long-doc/healthcare-pubmed_100k-200k_2.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/long-doc/healthcare-pubmed_100k-200k_3.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/long-doc/healthcare-pubmed_30k-40k_10-merged.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/long-doc/healthcare-pubmed_40k-50k_5-merged.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/long-doc/law-lex_files_300k-400k.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/long-doc/law-lex_files_400k-500k.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/long-doc/law-lex_files_500k-600k.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/long-doc/law-lex_files_600k-700k.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/qa/arxiv.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/qa/finance.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/qa/healthcare.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/qa/law.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/qa/msmarco.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/qa/news.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/qa/web.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/air-bench/qa/wiki.jsonl (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/AmazonCounterfactualClassification.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/AmazonPolarityClassification.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/AmazonReviewsClassification.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/ArguAna.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/ArxivClusteringP2P.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/ArxivClusteringS2S.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/AskUbuntuDupQuestions.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/BIOSSES.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/Banking77Classification.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/BiorxivClusteringP2P.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/BiorxivClusteringS2S.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/CQADupstack.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/CQADupstackRetrieval.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/ClimateFEVER.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/DBPedia.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/EmotionClassification.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/FEVER.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/FiQA2018.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/HotpotQA.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/ImdbClassification.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/MSMARCO.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/MTOPDomainClassification.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/MTOPIntentClassification.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/MassiveIntentClassification.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/MassiveScenarioClassification.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/MedrxivClusteringP2P.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/MedrxivClusteringS2S.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/MindSmallReranking.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/NFCorpus.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/NQ.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/QuoraRetrieval.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/RedditClustering.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/RedditClusteringP2P.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/SCIDOCS.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/SICK-R.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/STS12.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/STS13.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/STS14.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/STS15.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/STS16.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/STS17.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/STS22.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/STSBenchmark.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/SciDocsRR.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/SciFact.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/SprintDuplicateQuestions.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/StackExchangeClustering.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/StackExchangeClusteringP2P.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/StackOverflowDupQuestions.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/SummEval.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/TRECCOVID.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/Touche2020.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/ToxicConversationsClassification.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/TweetSentimentExtractionClassification.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/TwentyNewsgroupsClustering.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/TwitterSemEval2015.csv (100%)
 rename {FlagEmbedding_old/llm_dense_retriever => FlagEmbedding/evaluation/mteb/utils}/examples/mteb/TwitterURLCorpus.csv (100%)
 create mode 100644 FlagEmbedding/evaluation/mteb/utils/prompts.py

diff --git a/FlagEmbedding/evaluation/beir/__main__.py b/FlagEmbedding/evaluation/beir/__main__.py
index e69de29..f050502 100644
--- a/FlagEmbedding/evaluation/beir/__main__.py
+++ b/FlagEmbedding/evaluation/beir/__main__.py
@@ -0,0 +1,96 @@
+from transformers import HfArgumentParser
+
+from FlagEmbedding import FlagAutoModel, FlagAutoReranker
+from FlagEmbedding.abc.evaluation import AbsModelArgs, AbsEmbedder, AbsReranker
+
+
+from utils.arguments import BEIREvalArgs
+from utils.data_loader import BEIRDataLoader
+from utils.evaluator import BEIREvaluator
+
+
+def get_models(model_args: AbsModelArgs):
+    retriever = FlagAutoModel.from_finetuned(
+        model_name_or_path=model_args.embedder_name_or_path,
+        normalize_embeddings=model_args.normalize_embeddings,
+        use_fp16=model_args.use_fp16,
+        query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
+        query_instruction_format=model_args.query_instruction_format_for_retrieval,
+        devices=model_args.devices,
+        examples_for_task=model_args.examples_for_task,
+        examples_instruction_format=model_args.examples_instruction_format,
+        trust_remote_code=model_args.trust_remote_code,
+        cache_dir=model_args.cache_dir
+    )
+    reranker = None
+    if model_args.reranker_name_or_path is not None:
+        reranker = FlagAutoReranker.from_finetuned(
+            model_name_or_path=model_args.reranker_name_or_path,
+            peft_path=model_args.reranker_peft_path,
+            use_fp16=model_args.use_fp16,
+            use_bf16=model_args.use_bf16,
+            query_instruction_for_rerank=model_args.query_instruction_for_rerank,
+            query_instruction_format=model_args.query_instruction_format_for_rerank,
+            passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
+            passage_instruction_format=model_args.passage_instruction_format_for_rerank,
+            cache_dir=model_args.cache_dir,
+            trust_remote_code=model_args.trust_remote_code,
+            devices=model_args.devices,
+            normalize=model_args.normalize,
+            prompt=model_args.prompt,
+            cutoff_layers=model_args.cutoff_layers,
+            compress_layers=model_args.compress_layers,
+            compress_ratio=model_args.compress_ratio,
+        )
+    return retriever, reranker
+
+
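+# Build the retriever (and optional reranker), wrap them with the top-k
+# settings the evaluator expects, then run BEIREvaluator once per dataset
+# named in --dataset_names.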
+def main():
+    parser = HfArgumentParser([AbsModelArgs, BEIREvalArgs])
+    model_args, eval_args = parser.parse_args_into_dataclasses()
+    model_args: AbsModelArgs
+    eval_args: BEIREvalArgs
+
+    retriever, reranker = get_models(model_args)
+    retriever = AbsEmbedder(
+        retriever,
+        search_top_k=eval_args.search_top_k,
+    )
+
+    if reranker is not None:
+        reranker = AbsReranker(
+            reranker,
+            rerank_top_k=eval_args.rerank_top_k,
+        )
+    else:
+        reranker = None
+
+    for dataset_name in eval_args.dataset_names:
+
+        data_loader = BEIRDataLoader(
+            dataset_dir=eval_args.dataset_dir,
+            cache_dir=eval_args.cache_path,
+            dataset_name=dataset_name
+        )
+
+        evaluation = BEIREvaluator(
+            data_loader=data_loader,
+            overwrite=eval_args.overwrite,
+        )
+
+        evaluation(
+            splits=eval_args.splits.split(),
+            search_results_save_dir=eval_args.output_dir,
+            retriever=retriever,
+            reranker=reranker,
+            corpus_embd_save_dir=eval_args.corpus_embd_save_dir,
+            retriever_batch_size=model_args.retriever_batch_size,
+            reranker_batch_size=model_args.reranker_batch_size,
+            retriever_query_max_length=model_args.retriever_query_max_length,
+            retriever_passage_max_length=model_args.retriever_passage_max_length,
+            reranker_max_length=model_args.reranker_max_length,
+        )
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/FlagEmbedding/evaluation/beir/run.sh b/FlagEmbedding/evaluation/beir/run.sh
new file mode 100644
index 0000000..013b160
--- /dev/null
+++ b/FlagEmbedding/evaluation/beir/run.sh
@@ -0,0 +1,12 @@
+python __main__.py \
+--dataset_dir /share/chaofan/code/FlagEmbedding_update/data/BEIR \
+--embedder_name_or_path BAAI/bge-large-en-v1.5 \
+--reranker_name_or_path BAAI/bge-reranker-large \
+--query_instruction_for_retrieval "Represent this sentence for searching relevant passages: " \
+--use_fp16 True \
+--devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7 \
+--cache_dir /share/shared_models \
+--corpus_embd_save_dir /share/chaofan/code/FlagEmbedding_update/data/BEIR_passage_embds \
+--reranker_max_length 512 \
+--dataset_names nfcorpus fiqa cqadupstack
+
diff --git a/FlagEmbedding/evaluation/beir/utils/arguments.py b/FlagEmbedding/evaluation/beir/utils/arguments.py
index 2ffae38..9090a27 100644
--- a/FlagEmbedding/evaluation/beir/utils/arguments.py
+++ b/FlagEmbedding/evaluation/beir/utils/arguments.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass, field
+from typing import List
 
 from FlagEmbedding.abc.evaluation.arguments import AbsEvalArgs
diff --git a/FlagEmbedding/evaluation/beir/utils/data_loader.py b/FlagEmbedding/evaluation/beir/utils/data_loader.py
index e7649ca..3dbb1fe 100644
--- a/FlagEmbedding/evaluation/beir/utils/data_loader.py
+++ b/FlagEmbedding/evaluation/beir/utils/data_loader.py
@@ -15,7 +15,7 @@ from FlagEmbedding.abc.evaluation.data_loader import AbsDataLoader
 logger = logging.getLogger(__name__)
 
 
-class MSMARCODataLoader(AbsDataLoader):
+class BEIRDataLoader(AbsDataLoader):
     def __init__(
         self,
         dataset_dir: str, # the dataset dir to load from
@@ -32,16 +32,15 @@ class MSMARCODataLoader(AbsDataLoader):
         self.split = 'test'
         if dataset_name == 'msmarco':
             self.split = 'dev'
-        os.makedirs(self.dataset_dir, exist_ok=True)
         if dataset_name != 'cqadupstack':
+            os.makedirs(self.dataset_dir, exist_ok=True)
             qrels_path = os.path.join(self.dataset_dir, 'qrels-{split}.json'.format(split=self.split))
             corpus_path = os.path.join(self.dataset_dir, 'corpus.json')
             queries_path = os.path.join(self.dataset_dir, 'queries-{split}.json'.format(split=self.split))
             queries, corpus, rels = {}, {}, {}
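+            # Corpus, queries and qrels are materialized as plain JSON on
+            # first use, so later runs skip the Hugging Face download.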
             if not os.path.exists(corpus_path):
                 dataset = datasets.load_dataset(
-                    'BeIR/{d}'.format(d=data_name),
+                    'BeIR/{d}'.format(d=dataset_name),
                     'corpus',
                     trust_remote_code=True,
                     cache_dir=os.getenv('HF_HUB_CACHE', cache_dir),
@@ -55,7 +54,7 @@ class MSMARCODataLoader(AbsDataLoader):
             if not os.path.exists(queries_path):
                 dataset = datasets.load_dataset(
-                    'BeIR/{d}'.format(d=data_name),
+                    'BeIR/{d}'.format(d=dataset_name),
                     'queries',
                     trust_remote_code=True,
                     cache_dir=os.getenv('HF_HUB_CACHE', cache_dir),
@@ -69,7 +68,7 @@ class MSMARCODataLoader(AbsDataLoader):
             if not os.path.exists(qrels_path):
                 dataset = datasets.load_dataset(
-                    'BeIR/{d}-qrels'.format(d=data_name),
+                    'BeIR/{d}-qrels'.format(d=dataset_name),
                     split=self.split,
                     trust_remote_code=True,
                     cache_dir=os.getenv('HF_HUB_CACHE', cache_dir),
@@ -93,10 +92,8 @@ class MSMARCODataLoader(AbsDataLoader):
             if not os.path.exists(corpus_path) or not os.path.exists(queries_path) or not os.path.exists(qrels_path):
                 full_path = os.path.join(data_path, sub_dataset_name)
                 corpus, queries, qrels = GenericDataLoader(data_folder=full_path).load(split="test")
-                for k in corpus_data.keys():
-                    corpus[k] = {
-                        (corpus[k]['title'] + ' ' + corpus[k]['text']).strip()
-                    }
+                for k in corpus.keys():
+                    corpus[k] = (corpus[k]['title'] + ' ' + corpus[k]['text']).strip()
 
                 with open(corpus_path, 'w') as f:
                     json.dump(corpus, f)
@@ -105,25 +102,40 @@ class MSMARCODataLoader(AbsDataLoader):
                     json.dump(queries, f)
 
                 with open(qrels_path, 'w') as f:
-                    json.dump(rels, f)
+                    json.dump(qrels, f)
 
-    def load_qrels(self):
-        if self.sub_dataset_name is None:
+    def load_qrels(self, sub_dataset_name: str = None, split: str = None):
+        if sub_dataset_name is None and split is not None:
+            sub_dataset_name = split.split('-')
+            if len(sub_dataset_name) == 1 or len(sub_dataset_name[0].strip()) == 0:
+                sub_dataset_name = None
+            else:
+                sub_dataset_name = sub_dataset_name[0].strip()
+        if sub_dataset_name is None:
             qrels_path = os.path.join(self.dataset_dir, 'qrels-{split}.json'.format(split=self.split))
             rels = json.load(open(qrels_path))
             return datasets.DatasetDict(rels)
         else:
-            rels_list = []
-            for sub_dataset_name in self.sub_dataset_names:
-                qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
-                rels = json.load(open(qrels_path))
-                rels_list.append(datasets.DatasetDict(rels))
-            return rels_list
+            # rels_list = []
+            # for sub_dataset_name in self.sub_dataset_names:
+            #     qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
+            #     rels = json.load(open(qrels_path))
+            #     rels_list.append(datasets.DatasetDict(rels))
+            # return rels_list
+            qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
+            rels = json.load(open(qrels_path))
+            return datasets.DatasetDict(rels)
 
-    def load_corpus(self):
-        if self.sub_dataset_name is None:
+    def load_corpus(self, sub_dataset_name: str = None, split: str = None):
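+        # A split such as "gaming-test" encodes "<sub_dataset>-<split>" (the
+        # CQADupstack case); a bare "test"/"dev" split leaves
+        # sub_dataset_name as None.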
+        if sub_dataset_name is None and split is not None:
+            sub_dataset_name = split.split('-')
+            if len(sub_dataset_name) == 1 or len(sub_dataset_name[0].strip()) == 0:
+                sub_dataset_name = None
+            else:
+                sub_dataset_name = sub_dataset_name[0].strip()
+        if sub_dataset_name is None:
             corpus_path = os.path.join(self.dataset_dir, 'corpus.json')
             corpus = json.load(open(corpus_path))
             for k in corpus.keys():
@@ -132,15 +144,28 @@ class MSMARCODataLoader(AbsDataLoader):
                 }
             return datasets.DatasetDict(corpus)
         else:
-            corpus_list = []
-            for sub_dataset_name in self.sub_dataset_names:
-                corpus_path = os.path.join(self.dataset_dir, sub_dataset_name, 'corpus.json')
-                corpus = json.load(open(corpus_path))
-                corpus_list.append(datasets.DatasetDict(corpus))
-            return corpus_list
+            corpus_path = os.path.join(self.dataset_dir, sub_dataset_name, 'corpus.json')
+            corpus = json.load(open(corpus_path))
+            for k in corpus.keys():
+                corpus[k] = {
+                    'text': corpus[k]
+                }
+            return datasets.DatasetDict(corpus)
+            # corpus_list = []
+            # for sub_dataset_name in self.sub_dataset_names:
+            #     corpus_path = os.path.join(self.dataset_dir, sub_dataset_name, 'corpus.json')
+            #     corpus = json.load(open(corpus_path))
+            #     corpus_list.append(datasets.DatasetDict(corpus))
+            # return corpus_list
 
-    def load_queries(self, split='dev'):
-        if self.sub_dataset_name is None:
+    def load_queries(self, sub_dataset_name: str = None, split: str = None):
+        if sub_dataset_name is None and split is not None:
+            sub_dataset_name = split.split('-')
+            if len(sub_dataset_name) == 1 or len(sub_dataset_name[0].strip()) == 0:
+                sub_dataset_name = None
+            else:
+                sub_dataset_name = sub_dataset_name[0].strip()
+        if sub_dataset_name is None:
             queries_path = os.path.join(self.dataset_dir, 'queries-{split}.json'.format(split=self.split))
             qrels_path = os.path.join(self.dataset_dir, 'qrels-{split}.json'.format(split=self.split))
             queries = json.load(open(queries_path))
@@ -151,15 +176,25 @@ class MSMARCODataLoader(AbsDataLoader):
                 new_queries[k] = queries[k]
             return datasets.DatasetDict(new_queries)
         else:
-            queries_list = []
-            for sub_dataset_name in self.sub_dataset_names:
-                queries_path = os.path.join(self.dataset_dir, sub_dataset_name, 'queries-{split}.json'.format(split=self.split))
-                qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
-                queries = json.load(open(queries_path))
-                rels = json.load(open(qrels_path))
-                new_queries = {}
-                for k in queries.keys():
-                    if k in rels.keys():
-                        new_queries[k] = queries[k]
-                queries_list.append(datasets.DatasetDict(new_queries))
-            return queries_list
+            queries_path = os.path.join(self.dataset_dir, sub_dataset_name, 'queries-{split}.json'.format(split=self.split))
+            qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
+            queries = json.load(open(queries_path))
+            rels = json.load(open(qrels_path))
+            new_queries = {}
+            for k in queries.keys():
+                if k in rels.keys():
+                    new_queries[k] = queries[k]
+            return datasets.DatasetDict(new_queries)
+            # queries_list = []
+            # for sub_dataset_name in self.sub_dataset_names:
+            #     queries_path = os.path.join(self.dataset_dir, sub_dataset_name, 'queries-{split}.json'.format(split=self.split))
+            #     qrels_path = os.path.join(self.dataset_dir, sub_dataset_name, 'qrels-{split}.json'.format(split=self.split))
+            #     queries = json.load(open(queries_path))
+            #     rels = json.load(open(qrels_path))
+            #     new_queries = {}
+            #     for k in queries.keys():
+            #         if k in rels.keys():
+            #             new_queries[k] = queries[k]
+            #     queries_list.append(datasets.DatasetDict(new_queries))
+            # return queries_list
diff --git a/FlagEmbedding/evaluation/beir/utils/evaluator.py b/FlagEmbedding/evaluation/beir/utils/evaluator.py
index cadd860..c9ee602 100644
--- a/FlagEmbedding/evaluation/beir/utils/evaluator.py
+++ b/FlagEmbedding/evaluation/beir/utils/evaluator.py
@@ -6,6 +6,8 @@ import pandas as pd
 from typing import Dict, Optional, List, Union
 
 from FlagEmbedding.abc.evaluation.evaluator import AbsEvaluator
+from FlagEmbedding.abc.evaluation.data_loader import AbsDataLoader
+from FlagEmbedding.abc.evaluation.searcher import AbsEmbedder, AbsReranker
 
 
 class BEIREvaluator(AbsEvaluator):
     def __init__(
@@ -33,50 +35,55 @@ class BEIREvaluator(AbsEvaluator):
     ):
         dataset_name = self.data_loader.dataset_name
         sub_dataset_names = self.data_loader.sub_dataset_names
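+        # sub_dataset_names is only populated for cqadupstack, which is
+        # evaluated per StackExchange sub-forum; every other BEIR dataset is
+        # a single corpus.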
-        if sub_dataset_name
+        split = self.data_loader.split
 
         if isinstance(splits, str):
             splits = [splits]
 
         # Retrieval Stage
         no_reranker_search_results_save_dir = os.path.join(
-            search_results_save_dir, str(retriever), "NoReranker"
+            search_results_save_dir, str(retriever), "NoReranker", dataset_name
         )
         os.makedirs(no_reranker_search_results_save_dir, exist_ok=True)
 
+        if corpus_embd_save_dir is not None:
             corpus_embd_save_dir = os.path.join(corpus_embd_save_dir, dataset_name)
 
         flag = False
         if sub_dataset_names is None:
             split_no_reranker_search_results_save_path = os.path.join(
-                no_reranker_search_results_save_dir, f"{dataset_name}.json"
+                no_reranker_search_results_save_dir, f"{split}.json"
             )
             if not os.path.exists(split_no_reranker_search_results_save_path) or self.overwrite:
                 flag = True
-                break
         else:
-            for sub_dataset_name
+            for sub_dataset_name in sub_dataset_names:
+                split_no_reranker_search_results_save_path = os.path.join(
+                    no_reranker_search_results_save_dir, f"{sub_dataset_name}-{split}.json"
+                )
+                if not os.path.exists(split_no_reranker_search_results_save_path) or self.overwrite:
+                    flag = True
+                    break
 
         no_reranker_search_results_dict = {}
         if flag:
-            corpus = self.data_loader.load_corpus()
+            if sub_dataset_names is None:
+                corpus = self.data_loader.load_corpus(sub_dataset_name=sub_dataset_names)
 
-            queries_dict = {
-                split: self.data_loader.load_queries(split=split)
-                for split in splits
-            }
+                queries_dict = {
+                    split: self.data_loader.load_queries(sub_dataset_name=sub_dataset_names)
+                }
 
-            all_queries = {}
-            for _, split_queries in queries_dict.items():
-                all_queries.update(split_queries)
+                all_queries = {}
+                for _, split_queries in queries_dict.items():
+                    all_queries.update(split_queries)
 
-            all_no_reranker_search_results = retriever(
-                corpus=corpus,
-                queries=all_queries,
-                corpus_embd_save_dir=corpus_embd_save_dir,
-                batch_size=retriever_batch_size,
-                query_max_length=retriever_query_max_length,
-                passage_max_length=retriever_passage_max_length,
-                **kwargs,
-            )
+                all_no_reranker_search_results = retriever(
+                    corpus=corpus,
+                    queries=all_queries,
+                    corpus_embd_save_dir=corpus_embd_save_dir,
+                    batch_size=retriever_batch_size,
+                    query_max_length=retriever_query_max_length,
+                    passage_max_length=retriever_passage_max_length,
+                    **kwargs,
+                )
 
-            for split in splits:
                 split_queries = queries_dict[split]
                 no_reranker_search_results_dict[split] = {
                     qid: all_no_reranker_search_results[qid] for qid in split_queries
                 }
@@ -93,8 +100,47 @@ class BEIREvaluator(AbsEvaluator):
                     split=split,
                     dataset_dir=self.dataset_dir,
                 )
+            else:
+                for sub_dataset_name in sub_dataset_names:
+                    corpus = self.data_loader.load_corpus(sub_dataset_name=sub_dataset_name)
+
+                    queries_dict = {
+                        split: self.data_loader.load_queries(sub_dataset_name=sub_dataset_name)
+                    }
+
+                    all_queries = {}
+                    for _, split_queries in queries_dict.items():
+                        all_queries.update(split_queries)
+
+                    all_no_reranker_search_results = retriever(
+                        corpus=corpus,
+                        queries=all_queries,
+                        corpus_embd_save_dir=None if corpus_embd_save_dir is None else os.path.join(corpus_embd_save_dir, sub_dataset_name),
+                        batch_size=retriever_batch_size,
+                        query_max_length=retriever_query_max_length,
+                        passage_max_length=retriever_passage_max_length,
+                        **kwargs,
+                    )
+
+                    split_queries = queries_dict[split]
+                    no_reranker_search_results_dict[f"{sub_dataset_name}-{split}"] = {
+                        qid: all_no_reranker_search_results[qid] for qid in split_queries
+                    }
+                    split_no_reranker_search_results_save_path = os.path.join(
+                        no_reranker_search_results_save_dir, f"{sub_dataset_name}-{split}.json"
+                    )
+
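+                    # Results are keyed and saved as
+                    # "<sub_dataset>-<split>.json" so evaluate_results() can
+                    # aggregate all sub-datasets from one directory.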
+                    self.save_search_results(
+                        model_name=str(retriever),
+                        reranker_name="NoReranker",
+                        search_results=no_reranker_search_results_dict[f"{sub_dataset_name}-{split}"],
+                        output_path=split_no_reranker_search_results_save_path,
+                        split=f"{sub_dataset_name}-{split}",
+                        dataset_dir=self.dataset_dir,
+                    )
+
         else:
-            for split in splits:
+            if sub_dataset_names is None:
                 split_no_reranker_search_results_save_path = os.path.join(
                     no_reranker_search_results_save_dir, f"{split}.json"
                 )
@@ -107,47 +153,93 @@ class BEIREvaluator(AbsEvaluator):
                     split=split,
                 )
                 no_reranker_search_results_dict[split] = search_results
+            else:
+                for sub_dataset_name in sub_dataset_names:
+                    split_no_reranker_search_results_save_path = os.path.join(
+                        no_reranker_search_results_save_dir, f"{sub_dataset_name}-{split}.json"
+                    )
+                    data_info, search_results = self.load_search_results(split_no_reranker_search_results_save_path)
+
+                    self.check_data_info(
+                        data_info=data_info,
+                        model_name=str(retriever),
+                        reranker_name="NoReranker",
+                        split=f"{sub_dataset_name}-{split}",
+                    )
+                    no_reranker_search_results_dict[f"{sub_dataset_name}-{split}"] = search_results
 
         retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir)
         self.output_eval_results_to_json(retriever_eval_results, os.path.join(no_reranker_search_results_save_dir, 'eval.json'))
 
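+        # Reranking reuses the first-stage candidates collected in
+        # no_reranker_search_results_dict, so the retrieval stage above must
+        # have produced (or loaded) results for every split first.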
         # Reranking Stage
         if reranker is not None:
             reranker_search_results_save_dir = os.path.join(
-                search_results_save_dir, str(retriever), str(reranker)
+                search_results_save_dir, str(retriever), str(reranker), dataset_name
             )
             os.makedirs(reranker_search_results_save_dir, exist_ok=True)
 
-            corpus = self.data_loader.load_corpus()
+            if sub_dataset_names is None:
+                corpus = self.data_loader.load_corpus(sub_dataset_name=sub_dataset_names)
 
-            queries_dict = {
-                split: self.data_loader.load_queries(split=split)
-                for split in splits
-            }
+                queries_dict = {
+                    split: self.data_loader.load_queries(sub_dataset_name=sub_dataset_names)
+                }
 
-            for split in splits:
                 rerank_search_results_save_path = os.path.join(
                     reranker_search_results_save_dir, f"{split}.json"
                 )
 
                 if os.path.exists(rerank_search_results_save_path) and not self.overwrite:
-                    continue
+                    pass
+                else:
+                    rerank_search_results = reranker(
+                        corpus=corpus,
+                        queries=queries_dict[split],
+                        search_results=no_reranker_search_results_dict[split],
+                        batch_size=reranker_batch_size,
+                        max_length=reranker_max_length,
+                        **kwargs,
+                    )
 
-                rerank_search_results = reranker(
-                    corpus=corpus,
-                    queries=queries_dict[split],
-                    search_results=no_reranker_search_results_dict[split],
-                    batch_size=reranker_batch_size,
-                    max_length=reranker_max_length,
-                    **kwargs,
-                )
+                    self.save_search_results(
+                        model_name=str(retriever),
+                        reranker_name=str(reranker),
+                        search_results=rerank_search_results,
+                        output_path=rerank_search_results_save_path,
+                        split=split,
+                        dataset_dir=self.dataset_dir,
+                    )
+            else:
+                for sub_dataset_name in sub_dataset_names:
+                    corpus = self.data_loader.load_corpus(sub_dataset_name=sub_dataset_name)
+
+                    queries_dict = {
+                        split: self.data_loader.load_queries(sub_dataset_name=sub_dataset_name)
+                    }
+
+                    rerank_search_results_save_path = os.path.join(
+                        reranker_search_results_save_dir, f"{sub_dataset_name}-{split}.json"
+                    )
+
+                    if os.path.exists(rerank_search_results_save_path) and not self.overwrite:
+                        continue
+
+                    rerank_search_results = reranker(
+                        corpus=corpus,
+                        queries=queries_dict[split],
+                        search_results=no_reranker_search_results_dict[f"{sub_dataset_name}-{split}"],
+                        batch_size=reranker_batch_size,
+                        max_length=reranker_max_length,
+                        **kwargs,
+                    )
+
+                    self.save_search_results(
+                        model_name=str(retriever),
+                        reranker_name=str(reranker),
+                        search_results=rerank_search_results,
+                        output_path=rerank_search_results_save_path,
+                        split=f"{sub_dataset_name}-{split}",
+                        dataset_dir=self.dataset_dir,
+                    )
 
-            self.save_search_results(
-                model_name=str(retriever),
-                reranker_name=str(reranker),
-                search_results=rerank_search_results,
-                output_path=rerank_search_results_save_path,
-                split=split,
-                dataset_dir=self.dataset_dir,
-            )
             reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir)
             self.output_eval_results_to_json(reranker_eval_results, os.path.join(reranker_search_results_save_dir, 'eval.json'))
\ No newline at end of file
diff --git a/FlagEmbedding/evaluation/mteb/__main__.py b/FlagEmbedding/evaluation/mteb/__main__.py
index e69de29..56aec42 100644
--- a/FlagEmbedding/evaluation/mteb/__main__.py
+++ b/FlagEmbedding/evaluation/mteb/__main__.py
@@ -0,0 +1,44 @@
+from transformers import HfArgumentParser
+
+from FlagEmbedding import FlagAutoModel, FlagAutoReranker
+from FlagEmbedding.abc.evaluation import AbsModelArgs, AbsEmbedder, AbsReranker, AbsEvaluator
+
+
+from utils.arguments import MSMARCOEvalArgs
+from utils.data_loader import MSMARCODataLoader
+
+
+def get_models(model_args: AbsModelArgs):
+    retriever = FlagAutoModel.from_finetuned(
+        model_name_or_path=model_args.embedder_name_or_path,
+        normalize_embeddings=model_args.normalize_embeddings,
+        use_fp16=model_args.use_fp16,
+        query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
+        query_instruction_format=model_args.query_instruction_format_for_retrieval,
+        devices=model_args.devices,
+        examples_for_task=model_args.examples_for_task,
+        examples_instruction_format=model_args.examples_instruction_format,
+        trust_remote_code=model_args.trust_remote_code,
+        cache_dir=model_args.cache_dir
+    )
+    reranker = None
+    if model_args.reranker_name_or_path is not None:
+        reranker = FlagAutoReranker.from_finetuned(
+            model_name_or_path=model_args.reranker_name_or_path,
+            peft_path=model_args.reranker_peft_path,
+            use_fp16=model_args.use_fp16,
+            use_bf16=model_args.use_bf16,
+            query_instruction_for_rerank=model_args.query_instruction_for_rerank,
+            query_instruction_format=model_args.query_instruction_format_for_rerank,
+            passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
+            passage_instruction_format=model_args.passage_instruction_format_for_rerank,
+            cache_dir=model_args.cache_dir,
+            trust_remote_code=model_args.trust_remote_code,
+            devices=model_args.devices,
+            normalize=model_args.normalize,
+            prompt=model_args.prompt,
+            cutoff_layers=model_args.cutoff_layers,
+            compress_layers=model_args.compress_layers,
+            compress_ratio=model_args.compress_ratio,
+        )
+    return retriever, reranker
\ No newline at end of file
diff --git a/FlagEmbedding/evaluation/mteb/evaluate.py b/FlagEmbedding/evaluation/mteb/evaluate.py
index 6748715..3da10ba 100644
--- a/FlagEmbedding/evaluation/mteb/evaluate.py
+++ b/FlagEmbedding/evaluation/mteb/evaluate.py
@@ -1,81 +1,329 @@
-from transformers import HfArgumentParser
+import multiprocessing
+import os
+import random
+import sys
 
-from FlagEmbedding import AutoFlagModel
+import pandas as pd
+import torch
+import torch.nn.functional as F
+import tqdm
+import json
+import numpy as np
+import argparse
 
-from utils.arguments import ModelArgs
-from utils.models import SentenceTransformerEncoder, SentenceTransformerReranker
-from utils.searcher import EmbeddingModelRetriever, CrossEncoderReranker
+import mteb
+
+from peft import PeftModel
+from transformers import AutoModel, AutoTokenizer
+from transformers.modeling_outputs import BaseModelOutput
+from typing import List
+from mteb import MTEB
+
+from utils import logger, pool, move_to_cuda, get_detailed_instruct, get_task_def_by_task_name_and_type, create_batch_dict, tasks_desc, create_batch_query_dict
+
+parser = argparse.ArgumentParser(description='evaluation for MTEB benchmark except its Retrieval category')
+parser.add_argument('--task-types', nargs='+', default=[], help='task types to evaluate')
+parser.add_argument('--output-dir', default='',
+                    type=str, metavar='N', help='output directory')
+parser.add_argument('--model-name-or-path', default='tmp-outputs/',
+                    type=str, metavar='N', help='which model to use')
+parser.add_argument('--peft-name-or-path', default=None, type=str)
+parser.add_argument('--tokenizer-path', default=None)
+parser.add_argument('--embedding-path', default=None)
+parser.add_argument('--special-token', default=False, type=bool, help='whether to use special token')
+parser.add_argument('--zero-shot', default=False, type=bool, help='whether to use zero-shot ICL')
+parser.add_argument('--device', default=0, type=int)
+parser.add_argument('--all-num', default=8, type=int)
+parser.add_argument('--batch_size', default=32, type=int)
+parser.add_argument('--examples-dir', default='/share/chaofan/code/embedder/evaluate_for_icl/examples', type=str)
+parser.add_argument('--eight-special-token', default=False, type=bool)
+parser.add_argument('--passage-prompt', default=False, type=bool)
+
+args = parser.parse_args()
+base_name: str = args.model_name_or_path.split('/')[-1]
+if args.eight_special_token is True:
+    args.pool_type = 'last_eight'
+else:
+    args.pool_type = 'last'
+
+logger.info('Args: {}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4)))
+assert args.pool_type in ['cls', 'avg', 'last', 'weightedavg', 'last_eight'], 'pool_type should be cls / avg / last / weightedavg / last_eight'
+os.makedirs(args.output_dir, exist_ok=True)
+ALL_NUM = args.all_num
+
+
+class DenseEncoder(torch.nn.Module):
+    def __init__(self, device, **kwargs):
+        super().__init__()
+        self.encoder = AutoModel.from_pretrained(args.model_name_or_path, use_cache=False)
+        if args.peft_name_or_path is not None:
+            peft_name_or_path = args.peft_name_or_path.split(',')
+            if args.embedding_path is not None:
+                self.encoder.set_input_embeddings(torch.load(args.embedding_path))
+            elif args.special_token and os.path.exists(os.path.join(peft_name_or_path[-1], 'embedding', 'emb.pth')):
+                self.encoder.set_input_embeddings(torch.load(os.path.join(peft_name_or_path[-1], 'embedding', 'emb.pth')))
+            for peft_path in peft_name_or_path:
+                self.encoder = PeftModel.from_pretrained(self.encoder, peft_path)
+                self.encoder = self.encoder.merge_and_unload()
+
+        if args.tokenizer_path is not None:
+            self.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
+        self.l2_normalize = True
+        self.prompt = None
+        self.prefix = ''
+        self.suffix = ''
+        self.gpu_count = torch.cuda.device_count()
+
+        self.encoder.half()
+        self.encoder.eval()
+        # self.encoder.cuda()
+        self.device = device
+        self.encoder = self.encoder.to(device)
+
+        self.eight_special_token = args.eight_special_token
+        if args.eight_special_token:
+            self.special_tokens = [self.tokenizer.eos_token_id] * 7
+        else:
+            self.special_tokens = []
+
+        self.passage_prompt = args.passage_prompt
+
+        self.batch_size = args.batch_size
+
+        # if self.gpu_count > 1:
+        #     self.encoder = torch.nn.DataParallel(self.encoder)
+
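+    # Embeddings are pooled from the decoder's last hidden state: "last"
+    # takes the final non-padding token, "last_eight" averages the last
+    # eight positions (paired with the seven extra EOS special tokens).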
+    @torch.no_grad()
+    def encode(self, sentences, **kwargs) -> np.ndarray:
+        """ Returns a list of embeddings for the given sentences.
+        Args:
+            sentences (`List[str]`): List of sentences to encode
+            batch_size (`int`): Batch size for the encoding
+
+        Returns:
+            `List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
+        """
+
+        input_texts: List[str] = [self.prompt + s for s in sentences]
+
+        encoded_embeds = []
+        batch_size = self.batch_size * self.gpu_count
+        for start_idx in tqdm.tqdm(range(0, len(input_texts), batch_size), desc='encoding', mininterval=10):
+            batch_input_texts: List[str] = input_texts[start_idx: start_idx + batch_size]
+
+            batch_dict = create_batch_query_dict(self.tokenizer, self.prefix, self.suffix, batch_input_texts, special_tokens=self.special_tokens)
+            # if self.device == 0:
+            #     print(self.tokenizer.decode(batch_dict['input_ids'][0]))
+            batch_dict = batch_dict.to(self.device)
+            # batch_dict = move_to_cuda(batch_dict)
+
+            with torch.cuda.amp.autocast():
+                outputs: BaseModelOutput = self.encoder(**batch_dict)
+                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
+                if self.l2_normalize:
+                    embeds = F.normalize(embeds, p=2, dim=-1)
+                encoded_embeds.append(embeds.cpu().numpy())
+
+        return np.concatenate(encoded_embeds, axis=0)
+
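+    # encode_queries() mirrors encode(): both prepend the task instruction
+    # and ICL prefix/suffix. encode_corpus() instead tokenizes passages with
+    # create_batch_dict and only the optional passage prompt.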
+    @torch.no_grad()
+    def encode_queries(self, sentences, **kwargs) -> np.ndarray:
+        """ Returns a list of embeddings for the given sentences.
+        Args:
+            sentences (`List[str]`): List of sentences to encode
+            batch_size (`int`): Batch size for the encoding
+
+        Returns:
+            `List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
+        """
+
+        input_texts: List[str] = [self.prompt + s for s in sentences]
+
+        encoded_embeds = []
+        batch_size = self.batch_size * self.gpu_count
+        for start_idx in tqdm.tqdm(range(0, len(input_texts), batch_size), desc='encoding', mininterval=10):
+            batch_input_texts: List[str] = input_texts[start_idx: start_idx + batch_size]
+
+            batch_dict = create_batch_query_dict(self.tokenizer, self.prefix, self.suffix, batch_input_texts, special_tokens=self.special_tokens)
+            batch_dict = move_to_cuda(batch_dict)
+
+            with torch.cuda.amp.autocast():
+                outputs: BaseModelOutput = self.encoder(**batch_dict)
+                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
+                if self.l2_normalize:
+                    embeds = F.normalize(embeds, p=2, dim=-1)
+                encoded_embeds.append(embeds.cpu().numpy())
+
+        return np.concatenate(encoded_embeds, axis=0)
+
+    @torch.no_grad()
+    def encode_corpus(self, sentences, **kwargs) -> np.ndarray:
+        """ Returns a list of embeddings for the given sentences.
+        Args:
+            sentences (`List[str]`): List of sentences to encode
+            batch_size (`int`): Batch size for the encoding
+
+        Returns:
+            `List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
+        """
+
+        if isinstance(sentences[0], str):
+            input_texts: List[str] = [s for s in sentences]
+        else:
+            input_texts: List[str] = [(s['text'] + ' ' + s['title']).strip() for s in sentences]
+
+        encoded_embeds = []
+        batch_size = self.batch_size * self.gpu_count
+        for start_idx in tqdm.tqdm(range(0, len(input_texts), batch_size), desc='encoding', mininterval=10):
+            batch_input_texts: List[str] = input_texts[start_idx: start_idx + batch_size]
+
+            batch_dict = create_batch_dict(self.tokenizer, batch_input_texts, special_tokens=self.special_tokens, passage_prompt=self.passage_prompt)
+            batch_dict = move_to_cuda(batch_dict)
+
+            with torch.cuda.amp.autocast():
+                outputs: BaseModelOutput = self.encoder(**batch_dict)
+                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
+                if self.l2_normalize:
+                    embeds = F.normalize(embeds, p=2, dim=-1)
+                encoded_embeds.append(embeds.cpu().numpy())
+
+        return np.concatenate(encoded_embeds, axis=0)
+
+    def set_prompt(self, prompt: str):
+        self.prompt = prompt
+
+    def set_prefix(self, prefix: str):
+        self.prefix = prefix
+
+    def set_suffix(self, suffix: str):
+        self.suffix = suffix
 
-def get_models(model_args: ModelArgs):
-    embedding_model = AutoFlagModel.from_finetuned(
-        model_name_or_path,
-        normalize_embeddings,
-        use_fp16,
-        query_instruction_for_retrieval,
-        query_instruction_format,
-        devices,
-        examples_for_task,
-        examples_instruction_format,
-        trust_remote_code,
-        cache_dir
-    )
-    cross_encoder = None
-    if model_args.reranker is not None:
-        cross_encoder = SentenceTransformerReranker(
-            model_name_or_path,
-            peft_path,
-            use_fp16,
-            use_bf16,
-            query_instruction_for_rerank,
-            query_instruction_format,
-            passage_instruction_for_rerank,
-            passage_instruction_format,
-            cache_dir,
-            trust_remote_code,
-            devices
-        )
-    return embedding_model, cross_encoder
+def main(device, all_pairs):
 
-def main():
-    parser = HfArgumentParser([ModelArgs, EvalArgs])
-    model_args, eval_args = parser.parse_args_into_dataclasses()
-    model_args: ModelArgs
-    eval_args: EvalArgs
+    torch.cuda.set_device(device)
+    model = DenseEncoder(device)
 
-    embedding_model, cross_encoder = get_models(model_args)
-
-    evaluation = AIRBench(
-        benchmark_version=eval_args.benchmark_version,
-        task_types=eval_args.task_types,
-        domains=eval_args.domains,
-        languages=eval_args.languages,
-        splits=eval_args.splits,
-        cache_dir=eval_args.cache_dir,
-    )
-
-    retriever = EmbeddingModelRetriever(
-        embedding_model,
-        search_top_k=eval_args.search_top_k,
-        corpus_chunk_size=model_args.corpus_chunk_size,
-    )
-
-    if cross_encoder is not None:
-        reranker = CrossEncoderReranker(
-            cross_encoder,
-            rerank_top_k=eval_args.rerank_top_k,
-        )
+    os.environ['CUDA_VISIBLE_DEVICES'] = f'{device}'
+
+    ArxivClusteringP2P_FLAG = False
+    tmp = None
+    for i, p in enumerate(all_pairs):
+        if p[1] == 'ArxivClusteringP2P':
+            ArxivClusteringP2P_FLAG = True
+            tmp = p
+            break
+    if ArxivClusteringP2P_FLAG:
+        all_pairs.remove(tmp)
+
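+    # ArxivClusteringP2P is handled specially: the last worker runs it
+    # alone, while the remaining tasks are sharded evenly across the other
+    # workers.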
+    if ArxivClusteringP2P_FLAG is False:
+        length = len(all_pairs)
+        start = device * length // ALL_NUM
+        if device == ALL_NUM - 1:
+            end = length
+        else:
+            end = (device + 1) * length // ALL_NUM
+        all_pairs = all_pairs[start: end]
     else:
-        reranker = None
-
-    evaluation.run(
-        retriever,
-        reranker=reranker,
-        output_dir=eval_args.output_dir,
-        overwrite=eval_args.overwrite,
-    )
+        if device == ALL_NUM - 1:
+            all_pairs = [tmp]
+        else:
+            all_num = ALL_NUM - 1
+            length = len(all_pairs)
+            start = device * length // all_num
+            if device == all_num - 1:
+                end = length
+            else:
+                end = (device + 1) * length // all_num
+            all_pairs = all_pairs[start: end]
+
+    for (task_type, task_name) in all_pairs:
+        task_def: str = get_task_def_by_task_name_and_type(task_name=task_name, task_type=task_type)
+        prompt: str = get_detailed_instruct(task_def, args.special_token)
+        model.set_prompt(prompt=prompt)
+
+        eg_file_path = f'{args.examples_dir}/{task_name}.csv'
+        eg_pairs = []
+        if args.zero_shot:
+            eg_pairs = []
+        else:
+            df = pd.read_csv(eg_file_path)
+            for i in range(len(df)):
+                eg_pairs.append((get_detailed_instruct(
+                    task_def,
+                    args.special_token
+                ) + df[df.keys()[0]][i], df[df.keys()[1]][i]))
+        if args.special_token:
+            if len(eg_pairs) > 0:
+                prefix = '\n\n'.join(['\n'.join(eg_pairs[idx]) for idx in range(len(eg_pairs))]) + '\n\n'
+            else:
+                prefix = ''
+            suffix = '\n'
+        else:
+            if len(eg_pairs) > 0:
+                prefix = '\n\n'.join(['\nResponse: '.join(eg_pairs[idx]) for idx in range(len(eg_pairs))]) + '\n\n'
+            else:
+                prefix = ''
+            suffix = '\nResponse:'
+        model.set_prefix(prefix)
+        model.set_suffix(suffix)
+
+        logger.info('Set prompt: {}'.format(prompt))
+
+        # disable l2 normalize for classification tasks, as it achieves slightly better results
+        if task_type == 'Classification':
+            logger.info('Set l2_normalize to False for classification task')
+            model.l2_normalize = False
+        else:
+            model.l2_normalize = True
+        logger.info('Set l2_normalize to {}'.format(model.l2_normalize))
+
+        sub_eval = MTEB(tasks=[task_name], task_langs=['en'])
+
+        logger.info('Running evaluation for task: {}, type: {}'.format(task_name, task_type))
+
+        # eval_splits = ["test"] if "test" in task_cls.description["eval_splits"] else task_cls.description["eval_splits"]
+
+        result_flag = False
+        model.batch_size = args.batch_size
+        while result_flag is False:
+            try:
+                sub_eval.run(
+                    model,
+                    output_folder=args.output_dir
+                )
+                result_flag = True
+            except Exception as e:
+                print(e)
+                # back off the batch size on failures instead of looping
+                # forever once it reaches zero
+                model.batch_size -= 4
+                if model.batch_size <= 0:
+                    raise
 
-if __name__ == "__main__":
-    main()
\ No newline at end of file
+
+if __name__ == '__main__':
+    processes = []
+    multiprocessing.set_start_method('spawn')
+
+    random.seed(30)
+    args.task_types = [t for t in args.task_types if t.strip()]
+    all_pairs = []
+    for task_type in args.task_types:
+        if task_type in tasks_desc.keys():
+            for task_name in tasks_desc[task_type]:
+                all_pairs.append((task_type, task_name))
+    for task_type in tasks_desc.keys():
+        for v in tasks_desc[task_type]:
+            if v in args.task_types:
+                all_pairs.append((task_type, v))
+    all_pairs = list(set(all_pairs))
+    random.shuffle(all_pairs)
+
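+    # One worker process per shard; every worker gets the same shuffled
+    # task list and derives its own slice inside main().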
+    for i in range(ALL_NUM):
+        # i = 7
+        process = multiprocessing.Process(target=main, args=(i, all_pairs,))
+        processes.append(process)
+        process.start()
+
+    for process in processes:
+        process.join()
\ No newline at end of file
diff --git a/FlagEmbedding/evaluation/mteb/utils.py b/FlagEmbedding/evaluation/mteb/utils.py
new file mode 100644
index 0000000..137fa69
--- /dev/null
+++ b/FlagEmbedding/evaluation/mteb/utils.py
@@ -0,0 +1,365 @@
+import sys
+
+import torch
+import logging
+
+from torch import Tensor
+from transformers import PreTrainedTokenizerFast, BatchEncoding
+from typing import Mapping, Dict, List
+
+
+def _setup_logger():
+    log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(log_format)
+    logger.handlers = [console_handler]
+
+    return logger
+
+
+logger = _setup_logger()
+
+
+def move_to_cuda(sample):
+    if len(sample) == 0:
+        return {}
+
+    def _move_to_cuda(maybe_tensor):
+        if torch.is_tensor(maybe_tensor):
+            return maybe_tensor.cuda(non_blocking=True)
+        elif isinstance(maybe_tensor, dict):
+            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
+        elif isinstance(maybe_tensor, list):
+            return [_move_to_cuda(x) for x in maybe_tensor]
+        elif isinstance(maybe_tensor, tuple):
+            return tuple([_move_to_cuda(x) for x in maybe_tensor])
+        elif isinstance(maybe_tensor, Mapping):
+            return type(maybe_tensor)({k: _move_to_cuda(v) for k, v in maybe_tensor.items()})
+        else:
+            return maybe_tensor
+
+    return _move_to_cuda(sample)
+
+
+def pool(last_hidden_states: Tensor,
+         attention_mask: Tensor,
+         pool_type: str) -> Tensor:
+    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+
+    if pool_type == "avg":
+        emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+    elif pool_type == "weightedavg":  # position-weighted mean pooling from SGPT (https://arxiv.org/abs/2202.08904)
+        attention_mask *= attention_mask.cumsum(dim=1)  # [0,1,1,1,0,0] -> [0,1,2,3,0,0]
+        s = torch.sum(last_hidden * attention_mask.unsqueeze(-1).float(), dim=1)
+        d = attention_mask.sum(dim=1, keepdim=True).float()
+        emb = s / d
+    elif pool_type == "cls":
+        emb = last_hidden[:, 0]
+    elif pool_type == "last":
+        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+        if left_padding:
+            emb = last_hidden[:, -1]
+        else:
+            sequence_lengths = attention_mask.sum(dim=1) - 1
+            batch_size = last_hidden.shape[0]
+            emb = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths]
+    elif pool_type == "last_eight":
+        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+        if left_padding:
+            emb = torch.mean(last_hidden[:, -8:, :], dim=-2)
+        else:
+            sys.exit()
+    else:
+        raise ValueError(f"pool_type {pool_type} not supported")
+
+    return emb
+
+
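+# Both batch builders truncate in two passes: tokenize with a reduced
+# length budget, decode back to text, splice in the prompt/special tokens,
+# then re-tokenize padded to a multiple of 8 for efficient fp16 batching.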
+def create_batch_dict(tokenizer: PreTrainedTokenizerFast, input_texts: List[str], max_length: int = 512, special_tokens: list = [], passage_prompt: bool = False) -> BatchEncoding:
+    passage_suffix = ''
+    passage_suffix_len = 0
+    if passage_prompt:
+        passage_suffix = '\nSummarize the above passage: '
+        passage_suffix_len = len(tokenizer(passage_suffix, add_special_tokens=False)['input_ids'])
+    inputs = tokenizer(
+        input_texts,
+        max_length=max_length - len(tokenizer('', add_special_tokens=False)['input_ids']) - len(
+            tokenizer('', add_special_tokens=False)['input_ids']) - len(special_tokens) - passage_suffix_len,
+        return_token_type_ids=False,
+        truncation=True,
+        return_tensors=None,
+        add_special_tokens=False
+    )
+    input_texts = tokenizer.batch_decode(inputs['input_ids'])
+    if len(special_tokens) > 0:
+        for i in range(len(input_texts)):
+            input_texts[i] = input_texts[i] + passage_suffix + tokenizer.eos_token * 7
+    else:
+        for i in range(len(input_texts)):
+            input_texts[i] = input_texts[i] + passage_suffix
+
+    return tokenizer(
+        input_texts,
+        max_length=max_length,
+        padding=True,
+        pad_to_multiple_of=8,
+        return_token_type_ids=False,
+        truncation=True,
+        return_tensors='pt'
+    )
+
+
+def create_batch_query_dict(tokenizer: PreTrainedTokenizerFast, prefix: str, suffix: str, input_texts: List[str], max_length: int = 512, special_tokens: list = []) -> BatchEncoding:
+    inputs = tokenizer(
+        input_texts,
+        max_length=max_length - len(tokenizer('', add_special_tokens=False)['input_ids']) - len(tokenizer('\n', add_special_tokens=False)['input_ids']) - len(special_tokens),
+        return_token_type_ids=False,
+        truncation=True,
+        return_tensors=None,
+        add_special_tokens=False
+    )
+    prefix_ids = tokenizer(prefix, add_special_tokens=False)['input_ids']
+    suffix_ids = tokenizer(suffix, add_special_tokens=False)['input_ids']
+    new_max_length = (len(prefix_ids) + len(suffix_ids) + max_length) // 8 * 8 + 8
+
+    input_texts = tokenizer.batch_decode(inputs['input_ids'])
+    for i in range(len(input_texts)):
+        if len(special_tokens) > 0:
+            input_texts[i] = prefix + input_texts[i] + suffix + tokenizer.eos_token * 7
+        else:
+            input_texts[i] = prefix + input_texts[i] + suffix
+
+    return tokenizer(
+        input_texts,
+        max_length=new_max_length,
+        padding=True,
+        pad_to_multiple_of=8,
+        return_token_type_ids=False,
+        truncation=True,
+        return_tensors='pt'
+    )
+
+
+def get_task_def_by_task_name_and_type(task_name: str, task_type: str) -> str:
+    if task_type in ['STS']:
+        return "Retrieve semantically similar text."
+
+    if task_type in ['Summarization']:
+        return "Given a news summary, retrieve other semantically similar summaries."
+
+    if task_type in ['BitextMining']:
+        return "Retrieve parallel sentences."
+
+    if task_type in ['Classification']:
+        task_name_to_instruct: Dict[str, str] = {
+            'AmazonCounterfactualClassification': 'Classify a given Amazon customer review text as either counterfactual or not-counterfactual.',
+            'AmazonPolarityClassification': 'Classify Amazon reviews into positive or negative sentiment.',
+            'AmazonReviewsClassification': 'Classify the given Amazon review into its appropriate rating category.',
+            'Banking77Classification': 'Given a online banking query, find the corresponding intents.',
+            'EmotionClassification': 'Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise.',
+            'ImdbClassification': 'Classify the sentiment expressed in the given movie review text from the IMDB dataset.',
+            'MassiveIntentClassification': 'Given a user utterance as query, find the user intents.',
+            'MassiveScenarioClassification': 'Given a user utterance as query, find the user scenarios.',
+            'MTOPDomainClassification': 'Classify the intent domain of the given utterance in task-oriented conversation.',
+            'MTOPIntentClassification': 'Classify the intent of the given utterance in task-oriented conversation.',
+            'ToxicConversationsClassification': 'Classify the given comments as either toxic or not toxic.',
+            'TweetSentimentExtractionClassification': 'Classify the sentiment of a given tweet as either positive, negative, or neutral.',
+            # C-MTEB eval instructions
+            'TNews': 'Classify the fine-grained category of the given news title.',
+            'IFlyTek': 'Given an App description text, find the appropriate fine-grained category.',
+            'MultilingualSentiment': 'Classify sentiment of the customer review into positive, neutral, or negative.',
+            'JDReview': 'Classify the customer review for iPhone on e-commerce platform into positive or negative.',
+            'OnlineShopping': 'Classify the customer review for online shopping into positive or negative.',
+            'Waimai': 'Classify the customer review from a food takeaway platform into positive or negative.',
+        }
+        return task_name_to_instruct[task_name]
+
+    if task_type in ['Clustering']:
+        task_name_to_instruct: Dict[str, str] = {
'Identify the main and secondary category of Arxiv papers based on the titles and abstracts.', + 'ArxivClusteringS2S': 'Identify the main and secondary category of Arxiv papers based on the titles.', + 'BiorxivClusteringP2P': 'Identify the main category of Biorxiv papers based on the titles and abstracts.', + 'BiorxivClusteringS2S': 'Identify the main category of Biorxiv papers based on the titles.', + 'MedrxivClusteringP2P': 'Identify the main category of Medrxiv papers based on the titles and abstracts.', + 'MedrxivClusteringS2S': 'Identify the main category of Medrxiv papers based on the titles.', + 'RedditClustering': 'Identify the topic or theme of Reddit posts based on the titles.', + 'RedditClusteringP2P': 'Identify the topic or theme of Reddit posts based on the titles and posts.', + 'StackExchangeClustering': 'Identify the topic or theme of StackExchange posts based on the titles.', + 'StackExchangeClusteringP2P': 'Identify the topic or theme of StackExchange posts based on the given paragraphs.', + 'TwentyNewsgroupsClustering': 'Identify the topic or theme of the given news articles.', + # C-MTEB eval instructions + 'CLSClusteringS2S': 'Identify the main category of scholar papers based on the titles.', + 'CLSClusteringP2P': 'Identify the main category of scholar papers based on the titles and abstracts.', + 'ThuNewsClusteringS2S': 'Identify the topic or theme of the given news articles based on the titles.', + 'ThuNewsClusteringP2P': 'Identify the topic or theme of the given news articles based on the titles and contents.', + } + return task_name_to_instruct[task_name] + + if task_type in ['Reranking', 'PairClassification']: + task_name_to_instruct: Dict[str, str] = { + 'AskUbuntuDupQuestions': 'Retrieve duplicate questions from AskUbuntu forum.', + 'MindSmallReranking': 'Retrieve relevant news articles based on user browsing history.', + 'SciDocsRR': 'Given a title of a scientific paper, retrieve the titles of other relevant papers.', + 'StackOverflowDupQuestions': 'Retrieve duplicate questions from StackOverflow forum.', + 'SprintDuplicateQuestions': 'Retrieve duplicate questions from Sprint forum.', + 'TwitterSemEval2015': 'Retrieve tweets that are semantically similar to the given tweet.', + 'TwitterURLCorpus': 'Retrieve tweets that are semantically similar to the given tweet.', + # C-MTEB eval instructions + 'T2Reranking': 'Given a Chinese search query, retrieve web passages that answer the question.', + 'MMarcoReranking': 'Given a Chinese search query, retrieve web passages that answer the question.', + 'CMedQAv1': 'Given a Chinese community medical question, retrieve replies that best answer the question.', + 'CMedQAv2': 'Given a Chinese community medical question, retrieve replies that best answer the question.', + 'Ocnli': 'Retrieve semantically similar text.', + 'Cmnli': 'Retrieve semantically similar text.', + } + return task_name_to_instruct[task_name] + + if task_type in ['Retrieval']: + if task_name.lower().startswith('cqadupstack'): + return 'Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question.' 
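+ # Maps each MTEB / C-MTEB retrieval task name to the instruction prepended to its queries; lowercase and BEIR-style aliases are added after the dict.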
+ + task_name_to_instruct: Dict[str, str] = { + 'ArguAna': 'Given a claim, find documents that refute the claim.', + 'ClimateFEVER': 'Given a claim about climate change, retrieve documents that support or refute the claim.', + 'DBPedia': 'Given a query, retrieve relevant entity descriptions from DBPedia.', + 'FEVER': 'Given a claim, retrieve documents that support or refute the claim.', + 'FiQA2018': 'Given a financial question, retrieve user replies that best answer the question.', + 'HotpotQA': 'Given a multi-hop question, retrieve documents that can help answer the question.', + 'MSMARCO': 'Given a web search query, retrieve relevant passages that answer the query.', + 'NFCorpus': 'Given a question, retrieve relevant documents that best answer the question.', + 'NQ': 'Given a question, retrieve Wikipedia passages that answer the question.', + 'QuoraRetrieval': 'Given a question, retrieve questions that are semantically equivalent to the given question.', + 'SCIDOCS': 'Given a scientific paper title, retrieve paper abstracts that are cited by the given paper.', + 'SciFact': 'Given a scientific claim, retrieve documents that support or refute the claim.', + 'Touche2020': 'Given a question, retrieve detailed and persuasive arguments that answer the question.', + 'TRECCOVID': 'Given a query on COVID-19, retrieve documents that answer the query.', + # C-MTEB eval instructions + 'T2Retrieval': 'Given a Chinese search query, retrieve web passages that answer the question.', + 'MMarcoRetrieval': 'Given a web search query, retrieve relevant passages that answer the query.', + 'DuRetrieval': 'Given a Chinese search query, retrieve web passages that answer the question.', + 'CovidRetrieval': 'Given a question on COVID-19, retrieve news articles that answer the question.', + 'CmedqaRetrieval': 'Given a Chinese community medical question, retrieve replies that best answer the question.', + 'EcomRetrieval': 'Given a user query from an e-commerce website, retrieve description sentences of relevant products.', + 'MedicalRetrieval': 'Given a medical question, retrieve user replies that best answer the question.', + 'VideoRetrieval': 'Given a video search query, retrieve the titles of relevant videos.', + } + + # add lower case keys to match some beir names + task_name_to_instruct.update({k.lower(): v for k, v in task_name_to_instruct.items()}) + # other cases where lower case match still doesn't work + task_name_to_instruct['trec-covid'] = task_name_to_instruct['TRECCOVID'] + task_name_to_instruct['climate-fever'] = task_name_to_instruct['ClimateFEVER'] + task_name_to_instruct['dbpedia-entity'] = task_name_to_instruct['DBPedia'] + task_name_to_instruct['webis-touche2020'] = task_name_to_instruct['Touche2020'] + task_name_to_instruct['fiqa'] = task_name_to_instruct['FiQA2018'] + task_name_to_instruct['quora'] = task_name_to_instruct['QuoraRetrieval'] + + # for miracl evaluation + task_name_to_instruct['miracl'] = 'Given a question, retrieve Wikipedia passages that answer the question.' 
+ + return task_name_to_instruct[task_name] + + raise ValueError(f"No instruction config for task {task_name} with type {task_type}") + + +def get_detailed_instruct(task_description: str, special_token: bool) -> str: + if not task_description: + return '' + + if special_token: + return f'{task_description}\n' + else: + return f'Instruct: {task_description}\nQuery: ' + + +tasks_desc = { + 'Retrieval': [ + 'ArguAna', + 'ClimateFEVER', + 'DBPedia', + 'FEVER', + 'FiQA2018', + 'HotpotQA', + 'MSMARCO', + 'NFCorpus', + 'NQ', + 'QuoraRetrieval', + 'SCIDOCS', + 'SciFact', + 'Touche2020', + 'TRECCOVID', + 'CQADupstackAndroidRetrieval', + 'CQADupstackEnglishRetrieval', + 'CQADupstackGamingRetrieval', + 'CQADupstackGisRetrieval', + 'CQADupstackMathematicaRetrieval', + 'CQADupstackPhysicsRetrieval', + 'CQADupstackProgrammersRetrieval', + 'CQADupstackStatsRetrieval', + 'CQADupstackTexRetrieval', + 'CQADupstackUnixRetrieval', + 'CQADupstackWebmastersRetrieval', + 'CQADupstackWordpressRetrieval' + ], + 'Classification': [ + # 12 + 'AmazonCounterfactualClassification', + 'AmazonPolarityClassification', + 'AmazonReviewsClassification', + 'Banking77Classification', + 'EmotionClassification', + 'ImdbClassification', + 'MassiveIntentClassification', + 'MassiveScenarioClassification', + 'MTOPDomainClassification', + 'MTOPIntentClassification', + 'ToxicConversationsClassification', + 'TweetSentimentExtractionClassification', + ], + 'Clustering': [ + # 11 + 'ArxivClusteringP2P', + 'ArxivClusteringS2S', + 'BiorxivClusteringP2P', + 'BiorxivClusteringS2S', + 'MedrxivClusteringP2P', + 'MedrxivClusteringS2S', + 'RedditClustering', + 'RedditClusteringP2P', + 'StackExchangeClustering', + 'StackExchangeClusteringP2P', + 'TwentyNewsgroupsClustering', + ], + 'PairClassification': [ + # 3 + 'SprintDuplicateQuestions', + 'TwitterSemEval2015', + 'TwitterURLCorpus', + ], + 'Reranking': [ + # 4 + 'AskUbuntuDupQuestions', + 'MindSmallReranking', + 'SciDocsRR', + 'StackOverflowDupQuestions', + ], + 'STS': [ + # 10 + 'BIOSSES', + 'SICK-R', + 'STS12', + 'STS13', + 'STS14', + 'STS15', + 'STS16', + 'STS17', + 'STS22', + 'STSBenchmark', + ], + 'Summarization': [ + # 1 + 'SummEval', + ] +} \ No newline at end of file diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/arxiv-gemini.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/arxiv-gemini.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/arxiv-gemini.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/arxiv-gemini.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/arxiv-gpt3.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/arxiv-gpt3.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/arxiv-gpt3.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/arxiv-gpt3.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/arxiv-llama2.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/arxiv-llama2.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/arxiv-llama2.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/arxiv-llama2.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/arxiv-llm-survey.jsonl
b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/arxiv-llm-survey.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/arxiv-llm-survey.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/arxiv-llm-survey.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/book-a-brief-history-of-time_stephen-hawking.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/book-a-brief-history-of-time_stephen-hawking.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/book-a-brief-history-of-time_stephen-hawking.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/book-a-brief-history-of-time_stephen-hawking.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/book-origin-of-species_darwin.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/book-origin-of-species_darwin.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/book-origin-of-species_darwin.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/book-origin-of-species_darwin.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/healthcare-pubmed_100k-200k_1.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/healthcare-pubmed_100k-200k_1.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/healthcare-pubmed_100k-200k_1.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/healthcare-pubmed_100k-200k_1.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/healthcare-pubmed_100k-200k_2.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/healthcare-pubmed_100k-200k_2.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/healthcare-pubmed_100k-200k_2.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/healthcare-pubmed_100k-200k_2.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/healthcare-pubmed_100k-200k_3.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/healthcare-pubmed_100k-200k_3.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/healthcare-pubmed_100k-200k_3.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/healthcare-pubmed_100k-200k_3.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/healthcare-pubmed_30k-40k_10-merged.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/healthcare-pubmed_30k-40k_10-merged.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/healthcare-pubmed_30k-40k_10-merged.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/healthcare-pubmed_30k-40k_10-merged.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/healthcare-pubmed_40k-50k_5-merged.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/healthcare-pubmed_40k-50k_5-merged.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/healthcare-pubmed_40k-50k_5-merged.jsonl rename to 
FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/healthcare-pubmed_40k-50k_5-merged.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/law-lex_files_300k-400k.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/law-lex_files_300k-400k.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/law-lex_files_300k-400k.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/law-lex_files_300k-400k.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/law-lex_files_400k-500k.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/law-lex_files_400k-500k.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/law-lex_files_400k-500k.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/law-lex_files_400k-500k.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/law-lex_files_500k-600k.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/law-lex_files_500k-600k.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/law-lex_files_500k-600k.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/law-lex_files_500k-600k.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/law-lex_files_600k-700k.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/law-lex_files_600k-700k.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/long-doc/law-lex_files_600k-700k.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/long-doc/law-lex_files_600k-700k.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/qa/arxiv.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/qa/arxiv.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/qa/arxiv.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/qa/arxiv.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/qa/finance.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/qa/finance.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/qa/finance.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/qa/finance.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/qa/healthcare.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/qa/healthcare.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/qa/healthcare.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/qa/healthcare.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/qa/law.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/qa/law.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/qa/law.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/qa/law.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/qa/msmarco.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/qa/msmarco.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/qa/msmarco.jsonl rename to 
FlagEmbedding/evaluation/mteb/utils/examples/air-bench/qa/msmarco.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/qa/news.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/qa/news.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/qa/news.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/qa/news.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/qa/web.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/qa/web.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/qa/web.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/qa/web.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/air-bench/qa/wiki.jsonl b/FlagEmbedding/evaluation/mteb/utils/examples/air-bench/qa/wiki.jsonl similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/air-bench/qa/wiki.jsonl rename to FlagEmbedding/evaluation/mteb/utils/examples/air-bench/qa/wiki.jsonl diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/AmazonCounterfactualClassification.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/AmazonCounterfactualClassification.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/AmazonCounterfactualClassification.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/AmazonCounterfactualClassification.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/AmazonPolarityClassification.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/AmazonPolarityClassification.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/AmazonPolarityClassification.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/AmazonPolarityClassification.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/AmazonReviewsClassification.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/AmazonReviewsClassification.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/AmazonReviewsClassification.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/AmazonReviewsClassification.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/ArguAna.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/ArguAna.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/ArguAna.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/ArguAna.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/ArxivClusteringP2P.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/ArxivClusteringP2P.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/ArxivClusteringP2P.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/ArxivClusteringP2P.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/ArxivClusteringS2S.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/ArxivClusteringS2S.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/ArxivClusteringS2S.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/ArxivClusteringS2S.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/AskUbuntuDupQuestions.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/AskUbuntuDupQuestions.csv similarity index 100% rename from 
FlagEmbedding_old/llm_dense_retriever/examples/mteb/AskUbuntuDupQuestions.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/AskUbuntuDupQuestions.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/BIOSSES.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/BIOSSES.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/BIOSSES.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/BIOSSES.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/Banking77Classification.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/Banking77Classification.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/Banking77Classification.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/Banking77Classification.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/BiorxivClusteringP2P.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/BiorxivClusteringP2P.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/BiorxivClusteringP2P.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/BiorxivClusteringP2P.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/BiorxivClusteringS2S.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/BiorxivClusteringS2S.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/BiorxivClusteringS2S.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/BiorxivClusteringS2S.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/CQADupstack.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/CQADupstack.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/CQADupstack.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/CQADupstack.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/CQADupstackRetrieval.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/CQADupstackRetrieval.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/CQADupstackRetrieval.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/CQADupstackRetrieval.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/ClimateFEVER.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/ClimateFEVER.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/ClimateFEVER.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/ClimateFEVER.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/DBPedia.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/DBPedia.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/DBPedia.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/DBPedia.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/EmotionClassification.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/EmotionClassification.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/EmotionClassification.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/EmotionClassification.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/FEVER.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/FEVER.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/FEVER.csv rename to 
FlagEmbedding/evaluation/mteb/utils/examples/mteb/FEVER.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/FiQA2018.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/FiQA2018.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/FiQA2018.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/FiQA2018.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/HotpotQA.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/HotpotQA.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/HotpotQA.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/HotpotQA.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/ImdbClassification.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/ImdbClassification.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/ImdbClassification.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/ImdbClassification.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/MSMARCO.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/MSMARCO.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/MSMARCO.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/MSMARCO.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/MTOPDomainClassification.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/MTOPDomainClassification.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/MTOPDomainClassification.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/MTOPDomainClassification.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/MTOPIntentClassification.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/MTOPIntentClassification.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/MTOPIntentClassification.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/MTOPIntentClassification.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/MassiveIntentClassification.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/MassiveIntentClassification.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/MassiveIntentClassification.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/MassiveIntentClassification.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/MassiveScenarioClassification.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/MassiveScenarioClassification.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/MassiveScenarioClassification.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/MassiveScenarioClassification.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/MedrxivClusteringP2P.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/MedrxivClusteringP2P.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/MedrxivClusteringP2P.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/MedrxivClusteringP2P.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/MedrxivClusteringS2S.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/MedrxivClusteringS2S.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/MedrxivClusteringS2S.csv rename to 
FlagEmbedding/evaluation/mteb/utils/examples/mteb/MedrxivClusteringS2S.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/MindSmallReranking.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/MindSmallReranking.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/MindSmallReranking.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/MindSmallReranking.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/NFCorpus.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/NFCorpus.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/NFCorpus.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/NFCorpus.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/NQ.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/NQ.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/NQ.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/NQ.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/QuoraRetrieval.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/QuoraRetrieval.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/QuoraRetrieval.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/QuoraRetrieval.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/RedditClustering.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/RedditClustering.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/RedditClustering.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/RedditClustering.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/RedditClusteringP2P.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/RedditClusteringP2P.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/RedditClusteringP2P.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/RedditClusteringP2P.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/SCIDOCS.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/SCIDOCS.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/SCIDOCS.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/SCIDOCS.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/SICK-R.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/SICK-R.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/SICK-R.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/SICK-R.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/STS12.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/STS12.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/STS12.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/STS12.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/STS13.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/STS13.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/STS13.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/STS13.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/STS14.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/STS14.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/STS14.csv rename to 
FlagEmbedding/evaluation/mteb/utils/examples/mteb/STS14.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/STS15.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/STS15.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/STS15.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/STS15.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/STS16.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/STS16.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/STS16.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/STS16.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/STS17.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/STS17.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/STS17.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/STS17.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/STS22.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/STS22.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/STS22.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/STS22.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/STSBenchmark.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/STSBenchmark.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/STSBenchmark.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/STSBenchmark.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/SciDocsRR.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/SciDocsRR.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/SciDocsRR.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/SciDocsRR.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/SciFact.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/SciFact.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/SciFact.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/SciFact.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/SprintDuplicateQuestions.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/SprintDuplicateQuestions.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/SprintDuplicateQuestions.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/SprintDuplicateQuestions.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/StackExchangeClustering.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/StackExchangeClustering.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/StackExchangeClustering.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/StackExchangeClustering.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/StackExchangeClusteringP2P.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/StackExchangeClusteringP2P.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/StackExchangeClusteringP2P.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/StackExchangeClusteringP2P.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/StackOverflowDupQuestions.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/StackOverflowDupQuestions.csv 
similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/StackOverflowDupQuestions.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/StackOverflowDupQuestions.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/SummEval.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/SummEval.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/SummEval.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/SummEval.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/TRECCOVID.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/TRECCOVID.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/TRECCOVID.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/TRECCOVID.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/Touche2020.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/Touche2020.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/Touche2020.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/Touche2020.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/ToxicConversationsClassification.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/ToxicConversationsClassification.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/ToxicConversationsClassification.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/ToxicConversationsClassification.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/TweetSentimentExtractionClassification.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/TweetSentimentExtractionClassification.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/TweetSentimentExtractionClassification.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/TweetSentimentExtractionClassification.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/TwentyNewsgroupsClustering.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/TwentyNewsgroupsClustering.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/TwentyNewsgroupsClustering.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/TwentyNewsgroupsClustering.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/TwitterSemEval2015.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/TwitterSemEval2015.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/TwitterSemEval2015.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/TwitterSemEval2015.csv diff --git a/FlagEmbedding_old/llm_dense_retriever/examples/mteb/TwitterURLCorpus.csv b/FlagEmbedding/evaluation/mteb/utils/examples/mteb/TwitterURLCorpus.csv similarity index 100% rename from FlagEmbedding_old/llm_dense_retriever/examples/mteb/TwitterURLCorpus.csv rename to FlagEmbedding/evaluation/mteb/utils/examples/mteb/TwitterURLCorpus.csv diff --git a/FlagEmbedding/evaluation/mteb/utils/prompts.py b/FlagEmbedding/evaluation/mteb/utils/prompts.py new file mode 100644 index 0000000..7cd874b --- /dev/null +++ b/FlagEmbedding/evaluation/mteb/utils/prompts.py @@ -0,0 +1,214 @@ +from typing import Dict + + +def get_task_def_by_task_name_and_type(task_name: str, task_type: str) -> str: + if task_type in ['STS']: + return "Retrieve semantically similar text."
+ + if task_type in ['Summarization']: + return "Given a news summary, retrieve other semantically similar summaries." + + if task_type in ['BitextMining']: + return "Retrieve parallel sentences." + + if task_type in ['Classification']: + task_name_to_instruct: Dict[str, str] = { + 'AmazonCounterfactualClassification': 'Classify a given Amazon customer review text as either counterfactual or not-counterfactual.', + 'AmazonPolarityClassification': 'Classify Amazon reviews into positive or negative sentiment.', + 'AmazonReviewsClassification': 'Classify the given Amazon review into its appropriate rating category.', + 'Banking77Classification': 'Given a online banking query, find the corresponding intents.', + 'EmotionClassification': 'Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise.', + 'ImdbClassification': 'Classify the sentiment expressed in the given movie review text from the IMDB dataset.', + 'MassiveIntentClassification': 'Given a user utterance as query, find the user intents.', + 'MassiveScenarioClassification': 'Given a user utterance as query, find the user scenarios.', + 'MTOPDomainClassification': 'Classify the intent domain of the given utterance in task-oriented conversation.', + 'MTOPIntentClassification': 'Classify the intent of the given utterance in task-oriented conversation.', + 'ToxicConversationsClassification': 'Classify the given comments as either toxic or not toxic.', + 'TweetSentimentExtractionClassification': 'Classify the sentiment of a given tweet as either positive, negative, or neutral.', + # C-MTEB eval instructions + 'TNews': 'Classify the fine-grained category of the given news title.', + 'IFlyTek': 'Given an App description text, find the appropriate fine-grained category.', + 'MultilingualSentiment': 'Classify sentiment of the customer review into positive, neutral, or negative.', + 'JDReview': 'Classify the customer review for iPhone on e-commerce platform into positive or negative.', + 'OnlineShopping': 'Classify the customer review for online shopping into positive or negative.', + 'Waimai': 'Classify the customer review from a food takeaway platform into positive or negative.', + } + return task_name_to_instruct[task_name] + + if task_type in ['Clustering']: + task_name_to_instruct: Dict[str, str] = { + 'ArxivClusteringP2P': 'Identify the main and secondary category of Arxiv papers based on the titles and abstracts.', + 'ArxivClusteringS2S': 'Identify the main and secondary category of Arxiv papers based on the titles.', + 'BiorxivClusteringP2P': 'Identify the main category of Biorxiv papers based on the titles and abstracts.', + 'BiorxivClusteringS2S': 'Identify the main category of Biorxiv papers based on the titles.', + 'MedrxivClusteringP2P': 'Identify the main category of Medrxiv papers based on the titles and abstracts.', + 'MedrxivClusteringS2S': 'Identify the main category of Medrxiv papers based on the titles.', + 'RedditClustering': 'Identify the topic or theme of Reddit posts based on the titles.', + 'RedditClusteringP2P': 'Identify the topic or theme of Reddit posts based on the titles and posts.', + 'StackExchangeClustering': 'Identify the topic or theme of StackExchange posts based on the titles.', + 'StackExchangeClusteringP2P': 'Identify the topic or theme of StackExchange posts based on the given paragraphs.', + 'TwentyNewsgroupsClustering': 'Identify the topic or theme of the given news articles.', + # C-MTEB eval instructions + 
'CLSClusteringS2S': 'Identify the main category of scholar papers based on the titles.', + 'CLSClusteringP2P': 'Identify the main category of scholar papers based on the titles and abstracts.', + 'ThuNewsClusteringS2S': 'Identify the topic or theme of the given news articles based on the titles.', + 'ThuNewsClusteringP2P': 'Identify the topic or theme of the given news articles based on the titles and contents.', + } + return task_name_to_instruct[task_name] + + if task_type in ['Reranking', 'PairClassification']: + task_name_to_instruct: Dict[str, str] = { + 'AskUbuntuDupQuestions': 'Retrieve duplicate questions from AskUbuntu forum.', + 'MindSmallReranking': 'Retrieve relevant news articles based on user browsing history.', + 'SciDocsRR': 'Given a title of a scientific paper, retrieve the titles of other relevant papers.', + 'StackOverflowDupQuestions': 'Retrieve duplicate questions from StackOverflow forum.', + 'SprintDuplicateQuestions': 'Retrieve duplicate questions from Sprint forum.', + 'TwitterSemEval2015': 'Retrieve tweets that are semantically similar to the given tweet.', + 'TwitterURLCorpus': 'Retrieve tweets that are semantically similar to the given tweet.', + # C-MTEB eval instructions + 'T2Reranking': 'Given a Chinese search query, retrieve web passages that answer the question.', + 'MMarcoReranking': 'Given a Chinese search query, retrieve web passages that answer the question.', + 'CMedQAv1': 'Given a Chinese community medical question, retrieve replies that best answer the question.', + 'CMedQAv2': 'Given a Chinese community medical question, retrieve replies that best answer the question.', + 'Ocnli': 'Retrieve semantically similar text.', + 'Cmnli': 'Retrieve semantically similar text.', + } + return task_name_to_instruct[task_name] + + if task_type in ['Retrieval']: + if task_name.lower().startswith('cqadupstack'): + return 'Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question.' 
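+ # As in utils.py: maps each retrieval task name to its query instruction; lowercase and BEIR-style aliases follow the dict.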
+ + task_name_to_instruct: Dict[str, str] = { + 'ArguAna': 'Given a claim, find documents that refute the claim.', + 'ClimateFEVER': 'Given a claim about climate change, retrieve documents that support or refute the claim.', + 'DBPedia': 'Given a query, retrieve relevant entity descriptions from DBPedia.', + 'FEVER': 'Given a claim, retrieve documents that support or refute the claim.', + 'FiQA2018': 'Given a financial question, retrieve user replies that best answer the question.', + 'HotpotQA': 'Given a multi-hop question, retrieve documents that can help answer the question.', + 'MSMARCO': 'Given a web search query, retrieve relevant passages that answer the query.', + 'NFCorpus': 'Given a question, retrieve relevant documents that best answer the question.', + 'NQ': 'Given a question, retrieve Wikipedia passages that answer the question.', + 'QuoraRetrieval': 'Given a question, retrieve questions that are semantically equivalent to the given question.', + 'SCIDOCS': 'Given a scientific paper title, retrieve paper abstracts that are cited by the given paper.', + 'SciFact': 'Given a scientific claim, retrieve documents that support or refute the claim.', + 'Touche2020': 'Given a question, retrieve detailed and persuasive arguments that answer the question.', + 'TRECCOVID': 'Given a query on COVID-19, retrieve documents that answer the query.', + # C-MTEB eval instructions + 'T2Retrieval': 'Given a Chinese search query, retrieve web passages that answer the question.', + 'MMarcoRetrieval': 'Given a web search query, retrieve relevant passages that answer the query.', + 'DuRetrieval': 'Given a Chinese search query, retrieve web passages that answer the question.', + 'CovidRetrieval': 'Given a question on COVID-19, retrieve news articles that answer the question.', + 'CmedqaRetrieval': 'Given a Chinese community medical question, retrieve replies that best answer the question.', + 'EcomRetrieval': 'Given a user query from an e-commerce website, retrieve description sentences of relevant products.', + 'MedicalRetrieval': 'Given a medical question, retrieve user replies that best answer the question.', + 'VideoRetrieval': 'Given a video search query, retrieve the titles of relevant videos.', + } + + # add lower case keys to match some beir names + task_name_to_instruct.update({k.lower(): v for k, v in task_name_to_instruct.items()}) + # other cases where lower case match still doesn't work + task_name_to_instruct['trec-covid'] = task_name_to_instruct['TRECCOVID'] + task_name_to_instruct['climate-fever'] = task_name_to_instruct['ClimateFEVER'] + task_name_to_instruct['dbpedia-entity'] = task_name_to_instruct['DBPedia'] + task_name_to_instruct['webis-touche2020'] = task_name_to_instruct['Touche2020'] + task_name_to_instruct['fiqa'] = task_name_to_instruct['FiQA2018'] + task_name_to_instruct['quora'] = task_name_to_instruct['QuoraRetrieval'] + + # for miracl evaluation + task_name_to_instruct['miracl'] = 'Given a question, retrieve Wikipedia passages that answer the question.' 
+ + return task_name_to_instruct[task_name] + + raise ValueError(f"No instruction config for task {task_name} with type {task_type}") + + +tasks_desc = { + 'Retrieval': [ + 'ArguAna', + 'ClimateFEVER', + 'DBPedia', + 'FEVER', + 'FiQA2018', + 'HotpotQA', + 'MSMARCO', + 'NFCorpus', + 'NQ', + 'QuoraRetrieval', + 'SCIDOCS', + 'SciFact', + 'Touche2020', + 'TRECCOVID', + 'CQADupstackAndroidRetrieval', + 'CQADupstackEnglishRetrieval', + 'CQADupstackGamingRetrieval', + 'CQADupstackGisRetrieval', + 'CQADupstackMathematicaRetrieval', + 'CQADupstackPhysicsRetrieval', + 'CQADupstackProgrammersRetrieval', + 'CQADupstackStatsRetrieval', + 'CQADupstackTexRetrieval', + 'CQADupstackUnixRetrieval', + 'CQADupstackWebmastersRetrieval', + 'CQADupstackWordpressRetrieval' + ], + 'Classification': [ + # 12 + 'AmazonCounterfactualClassification', + 'AmazonPolarityClassification', + 'AmazonReviewsClassification', + 'Banking77Classification', + 'EmotionClassification', + 'ImdbClassification', + 'MassiveIntentClassification', + 'MassiveScenarioClassification', + 'MTOPDomainClassification', + 'MTOPIntentClassification', + 'ToxicConversationsClassification', + 'TweetSentimentExtractionClassification', + ], + 'Clustering': [ + # 11 + 'ArxivClusteringP2P', + 'ArxivClusteringS2S', + 'BiorxivClusteringP2P', + 'BiorxivClusteringS2S', + 'MedrxivClusteringP2P', + 'MedrxivClusteringS2S', + 'RedditClustering', + 'RedditClusteringP2P', + 'StackExchangeClustering', + 'StackExchangeClusteringP2P', + 'TwentyNewsgroupsClustering', + ], + 'PairClassification': [ + # 3 + 'SprintDuplicateQuestions', + 'TwitterSemEval2015', + 'TwitterURLCorpus', + ], + 'Reranking': [ + # 4 + 'AskUbuntuDupQuestions', + 'MindSmallReranking', + 'SciDocsRR', + 'StackOverflowDupQuestions', + ], + 'STS': [ + # 10 + 'BIOSSES', + 'SICK-R', + 'STS12', + 'STS13', + 'STS14', + 'STS15', + 'STS16', + 'STS17', + 'STS22', + 'STSBenchmark', + ], + 'Summarization': [ + # 1 + 'SummEval', + ] +} \ No newline at end of file diff --git a/setup.py b/setup.py index 63adb39..4bd0e69 100644 --- a/setup.py +++ b/setup.py @@ -20,5 +20,8 @@ setup( 'accelerate>=0.20.1', 'sentence_transformers', 'peft', + 'beir', + 'deepspeed', + 'flash-attn' ], )
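
For reference, a minimal sketch of how the helpers added in FlagEmbedding/evaluation/mteb/utils.py compose at encoding time. The checkpoint name, task name, and query/passage strings below are illustrative only and are not part of the patch; any decoder-style embedding model with an EOS token should behave the same way, and a CUDA device is assumed for move_to_cuda.

import torch
from transformers import AutoModel, AutoTokenizer

from FlagEmbedding.evaluation.mteb.utils import (
    create_batch_dict, create_batch_query_dict, get_detailed_instruct,
    get_task_def_by_task_name_and_type, move_to_cuda, pool,
)

model_name = "intfloat/e5-mistral-7b-instruct"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).cuda().eval()

# Look up the task instruction and turn it into a query prefix.
task = get_task_def_by_task_name_and_type("NQ", "Retrieval")
prefix = get_detailed_instruct(task, special_token=False)  # 'Instruct: ...\nQuery: '

queries = ["who wrote on the origin of species"]
passages = ["On the Origin of Species was published by Charles Darwin in 1859."]

# Tokenize with the length budgeting done by the helpers, then move tensors to GPU.
q_batch = move_to_cuda(create_batch_query_dict(tokenizer, prefix, '', queries, max_length=512))
p_batch = move_to_cuda(create_batch_dict(tokenizer, passages, max_length=512))

with torch.no_grad():
    q_emb = pool(model(**q_batch).last_hidden_state, q_batch['attention_mask'], pool_type='last')
    p_emb = pool(model(**p_batch).last_hidden_state, p_batch['attention_mask'], pool_type='last')

q_emb = torch.nn.functional.normalize(q_emb, dim=-1)
p_emb = torch.nn.functional.normalize(p_emb, dim=-1)
print(q_emb @ p_emb.T)  # cosine similarity scores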