diff --git a/FlagEmbedding/abc/evaluation/data_loader.py b/FlagEmbedding/abc/evaluation/data_loader.py
index 2ce0254..82339e9 100644
--- a/FlagEmbedding/abc/evaluation/data_loader.py
+++ b/FlagEmbedding/abc/evaluation/data_loader.py
@@ -133,7 +133,7 @@ class AbsEvalDataLoader(ABC):
         else:
             corpus_data = datasets.load_dataset('json', data_files=corpus_path, cache_dir=self.cache_dir)['train']
 
-        corpus = {e['id']: {'text': e['text']} for e in corpus_data}
+        corpus = {e['id']: {'title': e.get('title', ""), 'text': e['text']} for e in corpus_data}
         return datasets.DatasetDict(corpus)
 
     def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
diff --git a/FlagEmbedding/abc/evaluation/searcher.py b/FlagEmbedding/abc/evaluation/searcher.py
index d0fd90c..09c8b47 100644
--- a/FlagEmbedding/abc/evaluation/searcher.py
+++ b/FlagEmbedding/abc/evaluation/searcher.py
@@ -99,9 +99,6 @@ class EvalDenseRetriever(EvalRetriever):
                 corpus_emb = np.load(os.path.join(corpus_embd_save_dir, "doc.npy"))
             else:
                 corpus_emb = self.embedder.encode_corpus(corpus_texts, **kwargs)
-                if corpus_embd_save_dir is not None:
-                    os.makedirs(corpus_embd_save_dir, exist_ok=True)
-                    np.save(os.path.join(corpus_embd_save_dir, "doc.npy"), corpus_emb)
         else:
             corpus_emb = self.embedder.encode_corpus(corpus_texts, **kwargs)
 
@@ -112,6 +109,10 @@ class EvalDenseRetriever(EvalRetriever):
             corpus_emb = corpus_emb["dense_vecs"]
         if isinstance(queries_emb, dict):
             queries_emb = queries_emb["dense_vecs"]
+
+        if corpus_embd_save_dir is not None and not os.path.exists(os.path.join(corpus_embd_save_dir, "doc.npy")):
+            os.makedirs(corpus_embd_save_dir, exist_ok=True)
+            np.save(os.path.join(corpus_embd_save_dir, "doc.npy"), corpus_emb)
 
         faiss_index = index(corpus_embeddings=corpus_emb)
         all_scores, all_indices = search(query_embeddings=queries_emb, faiss_index=faiss_index, k=self.search_top_k)
@@ -174,8 +175,9 @@
         )
         # generate sentence pairs
         sentence_pairs = []
+        pairs = []
         for qid in search_results:
             for docid in search_results[qid]:
                 sentence_pairs.append(
                     {
                         "qid": qid,
@@ -182,12 +184,18 @@
                         "docid": docid,
                         "query": queries[qid],
                         "doc": corpus[docid]["text"] if "title" not in corpus[docid]
                         else f"{corpus[docid]['title']} {corpus[docid]['text']}".strip(),
                     }
                 )
-        pairs = [(e["query"], e["doc"]) for e in sentence_pairs]
+                pairs.append(
+                    (
+                        queries[qid],
+                        corpus[docid]["text"] if "title" not in corpus[docid]
+                        else f"{corpus[docid]['title']} {corpus[docid]['text']}".strip()
+                    )
+                )
         # compute scores
         scores = self.reranker.compute_score(pairs)
         for i, score in enumerate(scores):
             sentence_pairs[i]["score"] = float(score)
         # rerank
diff --git a/FlagEmbedding/abc/inference/AbsReranker.py b/FlagEmbedding/abc/inference/AbsReranker.py
index 864c742..cd8c294 100644
--- a/FlagEmbedding/abc/inference/AbsReranker.py
+++ b/FlagEmbedding/abc/inference/AbsReranker.py
@@ -126,6 +126,8 @@ class AbsReranker(ABC):
         sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
         **kwargs
     ):
+        if isinstance(sentence_pairs[0], str):
+            sentence_pairs = [sentence_pairs]
         sentence_pairs = self.get_detailed_inputs(sentence_pairs)
 
         if isinstance(sentence_pairs, str) or len(self.target_devices) == 1:
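With the `AbsReranker.compute_score` change above, a bare `(query, passage)` tuple is wrapped into a one-element list before processing. A minimal usage sketch (the checkpoint name is illustrative; any FlagEmbedding reranker exposes `compute_score`):

    from FlagEmbedding import FlagReranker

    reranker = FlagReranker("BAAI/bge-reranker-v2-m3", use_fp16=True)

    # A list of pairs returns one score per pair, as before.
    scores = reranker.compute_score([
        ("what is a panda?", "The giant panda is a bear species endemic to China."),
        ("what is a panda?", "Paris is the capital of France."),
    ])

    # A single bare tuple now also works: it is wrapped into [pair] internally.
    score = reranker.compute_score(
        ("what is a panda?", "The giant panda is a bear species endemic to China.")
    )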
diff --git a/FlagEmbedding/evaluation/beir/run.sh b/FlagEmbedding/evaluation/beir/run.sh
index da7b993..5ac6908 100644
--- a/FlagEmbedding/evaluation/beir/run.sh
+++ b/FlagEmbedding/evaluation/beir/run.sh
@@ -1,13 +1,13 @@
 python __main__.py \
 --dataset_dir /share/chaofan/code/FlagEmbedding_update/data/BEIR \
 --embedder_name_or_path BAAI/bge-large-en-v1.5 \
---reranker_name_or_path BAAI/bge-reranker-large \
+--reranker_name_or_path BAAI/bge-reranker-v2-m3 \
 --query_instruction_for_retrieval "Represent this sentence for searching relevant passages: " \
 --use_fp16 True \
 --devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7 \
 --cache_dir /share/shared_models \
 --corpus_embd_save_dir /share/chaofan/code/FlagEmbedding_update/data/BEIR_passage_embds \
---reranker_max_length 512 \
---dataset_names arguana \
---use_special_instructions True
+--reranker_max_length 1024 \
+--dataset_names trec-covid webis-touche2020 \
+--use_special_instructions False
diff --git a/FlagEmbedding/evaluation/msmarco/__init__.py b/FlagEmbedding/evaluation/msmarco/__init__.py
new file mode 100644
index 0000000..01d5a9e
--- /dev/null
+++ b/FlagEmbedding/evaluation/msmarco/__init__.py
@@ -0,0 +1,14 @@
+from FlagEmbedding.abc.evaluation import (
+    AbsEvalArgs as MSMARCOEvalArgs,
+    AbsEvalModelArgs as MSMARCOEvalModelArgs,
+)
+
+from .data_loader import MSMARCOEvalDataLoader
+from .runner import MSMARCOEvalRunner
+
+__all__ = [
+    "MSMARCOEvalArgs",
+    "MSMARCOEvalModelArgs",
+    "MSMARCOEvalRunner",
+    "MSMARCOEvalDataLoader",
+]
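Since `MSMARCOEvalArgs` and `MSMARCOEvalModelArgs` are import-time aliases rather than subclasses, they are the same class objects as the abstract dataclasses and accept exactly the same fields:

    from FlagEmbedding.abc.evaluation import AbsEvalArgs
    from FlagEmbedding.evaluation.msmarco import MSMARCOEvalArgs

    assert MSMARCOEvalArgs is AbsEvalArgs  # alias created by `import ... as ...`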
diff --git a/FlagEmbedding/evaluation/msmarco/__main__.py b/FlagEmbedding/evaluation/msmarco/__main__.py
index 32f2052..c2f2f94 100644
--- a/FlagEmbedding/evaluation/msmarco/__main__.py
+++ b/FlagEmbedding/evaluation/msmarco/__main__.py
@@ -1,95 +1,23 @@
 from transformers import HfArgumentParser
 
-from FlagEmbedding import FlagAutoModel, FlagAutoReranker
-from FlagEmbedding.abc.evaluation import AbsModelArgs, AbsEmbedder, AbsReranker, AbsEvaluator
+from FlagEmbedding.evaluation.msmarco import (
+    MSMARCOEvalArgs, MSMARCOEvalModelArgs,
+    MSMARCOEvalRunner
+)
 
-from utils.arguments import MSMARCOEvalArgs
-from utils.data_loader import MSMARCODataLoader
+parser = HfArgumentParser((
+    MSMARCOEvalArgs,
+    MSMARCOEvalModelArgs
+))
+eval_args, model_args = parser.parse_args_into_dataclasses()
+eval_args: MSMARCOEvalArgs
+model_args: MSMARCOEvalModelArgs
 
-
-def get_models(model_args: AbsModelArgs):
-    retriever = FlagAutoModel.from_finetuned(
-        model_name_or_path=model_args.embedder_name_or_path,
-        normalize_embeddings=model_args.normalize_embeddings,
-        use_fp16=model_args.use_fp16,
-        query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
-        query_instruction_format=model_args.query_instruction_format_for_retrieval,
-        devices=model_args.devices,
-        examples_for_task=model_args.examples_for_task,
-        examples_instruction_format=model_args.examples_instruction_format,
-        trust_remote_code=model_args.trust_remote_code,
-        cache_dir=model_args.cache_dir,
-        batch_size=model_args.retriever_batch_size,
-        query_max_length=model_args.retriever_query_max_length,
-        passage_max_length=model_args.retriever_passage_max_length,
-    )
-    reranker = None
-    if model_args.reranker_name_or_path is not None:
-        reranker = FlagAutoReranker.from_finetuned(
-            model_name_or_path=model_args.reranker_name_or_path,
-            peft_path=model_args.reranker_peft_path,
-            use_fp16=model_args.use_fp16,
-            use_bf16=model_args.use_bf16,
-            query_instruction_for_rerank=model_args.query_instruction_for_rerank,
-            query_instruction_format=model_args.query_instruction_format_for_rerank,
-            passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
-            passage_instruction_format=model_args.passage_instruction_format_for_rerank,
-            cache_dir=model_args.cache_dir,
-            trust_remote_code=model_args.trust_remote_code,
-            devices=model_args.devices,
-            normalize=model_args.normalize,
-            prompt=model_args.prompt,
-            cutoff_layers=model_args.cutoff_layers,
-            compress_layers=model_args.compress_layers,
-            compress_ratio=model_args.compress_ratio,
-            batch_size=model_args.reranker_batch_size,
-            query_max_length=model_args.reranker_query_max_length,
-            max_length=model_args.reranker_max_length,
-        )
-    return retriever, reranker
+runner = MSMARCOEvalRunner(
+    eval_args=eval_args,
+    model_args=model_args
+)
 
-
-def main():
-    parser = HfArgumentParser([AbsModelArgs, MSMARCOEvalArgs])
-    model_args, eval_args = parser.parse_args_into_dataclasses()
-    model_args: AbsModelArgs
-    eval_args: MSMARCOEvalArgs
-
-    retriever, reranker = get_models(model_args)
-
-    data_loader = MSMARCODataLoader(
-        dataset_dir = eval_args.dataset_dir,
-        cache_dir = eval_args.cache_path,
-        text_type = eval_args.text_type
-    )
-
-    evaluation = AbsEvaluator(
-        data_loader=data_loader,
-        overwrite=eval_args.overwrite,
-    )
-
-    retriever = AbsEmbedder(
-        retriever,
-        search_top_k=eval_args.search_top_k,
-    )
-
-    if reranker is not None:
-        reranker = AbsReranker(
-            reranker,
-            rerank_top_k=eval_args.rerank_top_k,
-        )
-    else:
-        reranker = None
-
-    evaluation(
-        splits=eval_args.splits.split(),
-        search_results_save_dir=eval_args.output_dir,
-        retriever=retriever,
-        reranker=reranker,
-        corpus_embd_save_dir=eval_args.corpus_embd_save_dir,
-    )
-
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
+runner.run()
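The new entry point can also be driven programmatically. A sketch, assuming (as `HfArgumentParser` requires) that the CLI flags used in `eval_msmarco.sh` below map one-to-one onto dataclass fields, and that unspecified fields have defaults:

    from transformers import HfArgumentParser
    from FlagEmbedding.evaluation.msmarco import (
        MSMARCOEvalArgs, MSMARCOEvalModelArgs, MSMARCOEvalRunner
    )

    parser = HfArgumentParser((MSMARCOEvalArgs, MSMARCOEvalModelArgs))
    # Pass an explicit argv instead of reading sys.argv.
    eval_args, model_args = parser.parse_args_into_dataclasses(args=[
        "--eval_name", "msmarco",
        "--dataset_names", "passage",
        "--splits", "dl19",
        "--embedder_name_or_path", "BAAI/bge-m3",
    ])
    MSMARCOEvalRunner(eval_args=eval_args, model_args=model_args).run()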
diff --git a/FlagEmbedding/evaluation/msmarco/__main__2.py b/FlagEmbedding/evaluation/msmarco/__main__2.py
new file mode 100644
index 0000000..32f2052
--- /dev/null
+++ b/FlagEmbedding/evaluation/msmarco/__main__2.py
@@ -0,0 +1,95 @@
+from transformers import HfArgumentParser
+
+from FlagEmbedding import FlagAutoModel, FlagAutoReranker
+from FlagEmbedding.abc.evaluation import AbsModelArgs, AbsEmbedder, AbsReranker, AbsEvaluator
+
+
+from utils.arguments import MSMARCOEvalArgs
+from utils.data_loader import MSMARCODataLoader
+
+
+def get_models(model_args: AbsModelArgs):
+    retriever = FlagAutoModel.from_finetuned(
+        model_name_or_path=model_args.embedder_name_or_path,
+        normalize_embeddings=model_args.normalize_embeddings,
+        use_fp16=model_args.use_fp16,
+        query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
+        query_instruction_format=model_args.query_instruction_format_for_retrieval,
+        devices=model_args.devices,
+        examples_for_task=model_args.examples_for_task,
+        examples_instruction_format=model_args.examples_instruction_format,
+        trust_remote_code=model_args.trust_remote_code,
+        cache_dir=model_args.cache_dir,
+        batch_size=model_args.retriever_batch_size,
+        query_max_length=model_args.retriever_query_max_length,
+        passage_max_length=model_args.retriever_passage_max_length,
+    )
+    reranker = None
+    if model_args.reranker_name_or_path is not None:
+        reranker = FlagAutoReranker.from_finetuned(
+            model_name_or_path=model_args.reranker_name_or_path,
+            peft_path=model_args.reranker_peft_path,
+            use_fp16=model_args.use_fp16,
+            use_bf16=model_args.use_bf16,
+            query_instruction_for_rerank=model_args.query_instruction_for_rerank,
+            query_instruction_format=model_args.query_instruction_format_for_rerank,
+            passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
+            passage_instruction_format=model_args.passage_instruction_format_for_rerank,
+            cache_dir=model_args.cache_dir,
+            trust_remote_code=model_args.trust_remote_code,
+            devices=model_args.devices,
+            normalize=model_args.normalize,
+            prompt=model_args.prompt,
+            cutoff_layers=model_args.cutoff_layers,
+            compress_layers=model_args.compress_layers,
+            compress_ratio=model_args.compress_ratio,
+            batch_size=model_args.reranker_batch_size,
+            query_max_length=model_args.reranker_query_max_length,
+            max_length=model_args.reranker_max_length,
+        )
+    return retriever, reranker
+
+
+def main():
+    parser = HfArgumentParser([AbsModelArgs, MSMARCOEvalArgs])
+    model_args, eval_args = parser.parse_args_into_dataclasses()
+    model_args: AbsModelArgs
+    eval_args: MSMARCOEvalArgs
+
+    retriever, reranker = get_models(model_args)
+
+    data_loader = MSMARCODataLoader(
+        dataset_dir = eval_args.dataset_dir,
+        cache_dir = eval_args.cache_path,
+        text_type = eval_args.text_type
+    )
+
+    evaluation = AbsEvaluator(
+        data_loader=data_loader,
+        overwrite=eval_args.overwrite,
+    )
+
+    retriever = AbsEmbedder(
+        retriever,
+        search_top_k=eval_args.search_top_k,
+    )
+
+    if reranker is not None:
+        reranker = AbsReranker(
+            reranker,
+            rerank_top_k=eval_args.rerank_top_k,
+        )
+    else:
+        reranker = None
+
+    evaluation(
+        splits=eval_args.splits.split(),
+        search_results_save_dir=eval_args.output_dir,
+        retriever=retriever,
+        reranker=reranker,
+        corpus_embd_save_dir=eval_args.corpus_embd_save_dir,
+    )
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/FlagEmbedding/evaluation/msmarco/data_loader.py b/FlagEmbedding/evaluation/msmarco/data_loader.py
new file mode 100644
index 0000000..4c89689
--- /dev/null
+++ b/FlagEmbedding/evaluation/msmarco/data_loader.py
@@ -0,0 +1,235 @@
+import os
+import json
+import logging
+import datasets
+from tqdm import tqdm
+from typing import List, Optional
+
+from FlagEmbedding.abc.evaluation import AbsEvalDataLoader
+
+logger = logging.getLogger(__name__)
+
+
+class MSMARCOEvalDataLoader(AbsEvalDataLoader):
+    def available_dataset_names(self) -> List[str]:
+        return ["passage", "document"]
+
+    def available_splits(self, dataset_name: str = None) -> List[str]:
+        return ["dev", "dl19", "dl20"]
+
+    def _load_remote_corpus(
+        self,
+        dataset_name: str,
+        save_dir: Optional[str] = None
+    ) -> datasets.DatasetDict:
+        if dataset_name == 'passage':
+            corpus = datasets.load_dataset(
+                'Tevatron/msmarco-passage-corpus',
+                'default',
+                trust_remote_code=True,
+                cache_dir=self.cache_dir,
+            )['train']
+        else:
+            corpus = datasets.load_dataset(
+                'irds/msmarco-document',
+                'docs',
+                trust_remote_code=True,
+                cache_dir=self.cache_dir,
+            )
+
+        if save_dir is not None:
+            os.makedirs(save_dir, exist_ok=True)
+            save_path = os.path.join(save_dir, "corpus.jsonl")
+            corpus_dict = {}
+            with open(save_path, "w", encoding="utf-8") as f:
+                for data in tqdm(corpus, desc="Loading and Saving corpus"):
+                    _data = {
+                        "id": data["docid"],
+                        "title": data["title"],
+                        "text": data.get("text", data.get("body", ""))
+                    }
+                    corpus_dict[data["docid"]] = {
+                        "title": data["title"],
+                        "text": data.get("text", data.get("body", ""))
+                    }
+                    f.write(json.dumps(_data, ensure_ascii=False) + "\n")
+            logger.info(f"{self.eval_name} {dataset_name} corpus saved to {save_path}")
+        else:
+            corpus_dict = {data["docid"]: {"title": data["title"], "text": data.get("text", data.get("body", ""))} for data in tqdm(corpus, desc="Loading corpus")}
+        return datasets.DatasetDict(corpus_dict)
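`_load_remote_corpus` writes one JSON object per line in the same shape that the patched `_load_local_corpus` above consumes, so a second run can reload the dump without touching the Hub. A sketch of the round trip (the path is illustrative):

    # Each line of corpus.jsonl: {"id": "...", "title": "...", "text": "..."}
    import datasets

    corpus_data = datasets.load_dataset(
        "json", data_files="/path/to/msmarco/passage/corpus.jsonl"
    )["train"]
    corpus = {e["id"]: {"title": e.get("title", ""), "text": e["text"]} for e in corpus_data}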
+    def _download_gz_file(self, download_url: str, save_dir: str):
+        cmd = f"wget -P {save_dir} {download_url}"
+        os.system(cmd)
+        save_path = os.path.join(save_dir, download_url.split('/')[-1])
+
+        if not os.path.exists(save_path) or os.path.getsize(save_path) == 0:
+            raise FileNotFoundError(f"Failed to download file from {download_url} to {save_path}")
+        else:
+            logger.info(f"Downloaded file from {download_url} to {save_path}")
+            cmd = f"gzip -d {save_path}"
+            os.system(cmd)
+            new_save_path = save_path.replace(".gz", "")
+            logger.info(f"Unzip file from {save_path} to {new_save_path}")
+            return new_save_path
+
+    def _download_file(self, download_url: str, save_dir: str):
+        cmd = f"wget -P {save_dir} {download_url}"
+        os.system(cmd)
+        save_path = os.path.join(save_dir, download_url.split('/')[-1])
+
+        if not os.path.exists(save_path) or os.path.getsize(save_path) == 0:
+            raise FileNotFoundError(f"Failed to download file from {download_url} to {save_path}")
+        else:
+            logger.info(f"Downloaded file from {download_url} to {save_path}")
+            return save_path
+
+    def _load_remote_qrels(
+        self,
+        dataset_name: Optional[str] = None,
+        split: str = 'dev',
+        save_dir: Optional[str] = None
+    ) -> datasets.DatasetDict:
+        if dataset_name == 'passage':
+            if split == 'dev':
+                qrels = datasets.load_dataset(
+                    'BeIR/msmarco-qrels',
+                    split='validation',
+                    trust_remote_code=True,
+                    cache_dir=self.cache_dir,
+                )
+                qrels_download_url = None
+            elif split == 'dl19':
+                qrels_download_url = "https://trec.nist.gov/data/deep/2019qrels-pass.txt"
+            else:
+                qrels_download_url = "https://trec.nist.gov/data/deep/2020qrels-pass.txt"
+        else:
+            if split == 'dev':
+                qrels_download_url = "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docdev-qrels.tsv.gz"
+            elif split == 'dl19':
+                qrels_download_url = "https://trec.nist.gov/data/deep/2019qrels-docs.txt"
+            else:
+                qrels_download_url = "https://trec.nist.gov/data/deep/2020qrels-docs.txt"
+
+        if qrels_download_url is not None:
+            if qrels_download_url.endswith(".gz"):
+                qrels_save_path = self._download_gz_file(qrels_download_url, self.cache_dir)
+            else:
+                qrels_save_path = self._download_file(qrels_download_url, self.cache_dir)
+        else:
+            qrels_save_path = None
+
+        if save_dir is not None:
+            os.makedirs(save_dir, exist_ok=True)
+            save_path = os.path.join(save_dir, f"{split}_qrels.jsonl")
+            qrels_dict = {}
+            if qrels_save_path is not None:
+                with open(save_path, "w", encoding="utf-8") as f1:
+                    with open(qrels_save_path, "r", encoding="utf-8") as f2:
+                        for line in tqdm(f2.readlines(), desc="Loading and Saving qrels"):
+                            qid, _, docid, rel = line.strip().split()
+                            qid, docid, rel = str(qid), str(docid), int(rel)
+                            _data = {
+                                "qid": qid,
+                                "docid": docid,
+                                "relevance": rel
+                            }
+                            if qid not in qrels_dict:
+                                qrels_dict[qid] = {}
+                            qrels_dict[qid][docid] = rel
+                            f1.write(json.dumps(_data, ensure_ascii=False) + "\n")
+            else:
+                with open(save_path, "w", encoding="utf-8") as f:
+                    for data in tqdm(qrels, desc="Loading and Saving qrels"):
+                        qid, docid, rel = str(data['query-id']), str(data['corpus-id']), int(data['score'])
+                        _data = {
+                            "qid": qid,
+                            "docid": docid,
+                            "relevance": rel
+                        }
+                        if qid not in qrels_dict:
+                            qrels_dict[qid] = {}
+                        qrels_dict[qid][docid] = rel
+                        f.write(json.dumps(_data, ensure_ascii=False) + "\n")
+            logger.info(f"{self.eval_name} {dataset_name} qrels saved to {save_path}")
+        else:
+            qrels_dict = {}
+            if qrels_save_path is not None:
+                with open(qrels_save_path, "r", encoding="utf-8") as f:
+                    for line in tqdm(f.readlines(), desc="Loading qrels"):
+                        qid, _, docid, rel = line.strip().split()
+                        qid, docid, rel = str(qid), str(docid), int(rel)
+                        if qid not in qrels_dict:
+                            qrels_dict[qid] = {}
+                        qrels_dict[qid][docid] = rel
+            else:
+                for data in tqdm(qrels, desc="Loading qrels"):
+                    qid, docid, rel = str(data['query-id']), str(data['corpus-id']), int(data['score'])
+                    if qid not in qrels_dict:
+                        qrels_dict[qid] = {}
+                    qrels_dict[qid][docid] = rel
+        return datasets.DatasetDict(qrels_dict)
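The TREC-style qrels files parsed above are whitespace-separated `qid <iteration> docid relevance` lines, which the loader folds into a nested per-query dict. An illustration with made-up ids:

    line = "19335 Q0 8412684 2"   # hypothetical qrels line
    qid, _, docid, rel = line.strip().split()

    qrels_dict = {}
    qrels_dict.setdefault(qid, {})[docid] = int(rel)
    # qrels_dict == {"19335": {"8412684": 2}}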
desc="Loading queries"): + qid, docid, rel = str(qid), str(docid), int(rel) + if qid not in qrels_dict: + qrels_dict[qid] = {} + qrels_dict[qid][docid] = rel + return datasets.DatasetDict(qrels_dict) + + def _load_remote_queries( + self, + dataset_name: Optional[str] = None, + split: str = 'test', + save_dir: Optional[str] = None + ) -> datasets.DatasetDict: + if split == 'dev': + if dataset_name == 'passage': + queries = datasets.load_dataset( + 'BeIR/msmarco', + 'queries', + trust_remote_code=True, + cache_dir=self.cache_dir, + )['queries'] + queries_save_path = None + else: + queries_download_url = "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docdev-qrels.tsv.gz" + queries_save_path = self._download_gz_file(queries_download_url, self.cache_dir) + else: + year = split.replace("dl", "") + queries_download_url = f"https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test20{year}-queries.tsv.gz" + queries_save_path = self._download_gz_file(queries_download_url, self.cache_dir) + + if save_dir is not None: + os.makedirs(save_dir, exist_ok=True) + save_path = os.path.join(save_dir, f"{split}_queries.jsonl") + queries_dict = {} + if queries_save_path is not None: + with open(save_path, "w", encoding="utf-8") as f1: + with open(queries_save_path, "r", encoding="utf-8") as f2: + for line in tqdm(f2.readlines(), desc="Loading and Saving queries"): + qid, query = line.strip().split("\t") + qid = str(qid) + _data = { + "id": qid, + "text": query + } + queries_dict[qid] = query + f1.write(json.dumps(_data, ensure_ascii=False) + "\n") + else: + with open(save_path, "w", encoding="utf-8") as f: + for data in tqdm(queries, desc="Loading and Saving queries"): + qid, query = data['_id'], data['text'] + _data = { + "id": qid, + "text": query + } + queries_dict[qid] = query + f.write(json.dumps(_data, ensure_ascii=False) + "\n") + logging.info(f"{self.eval_name} {dataset_name} queries saved to {save_path}") + else: + queries_dict = {} + if queries_save_path is not None: + with open(queries_save_path, "r", encoding="utf-8") as f: + for line in tqdm(f.readlines(), desc="Loading queries"): + qid, query = line.strip().split("\t") + qid = str(qid) + queries_dict[qid] = query + else: + for data in tqdm(queries, desc="Loading queries"): + qid, query = data['_id'], data['text'] + queries_dict[qid] = query + return datasets.DatasetDict(queries_dict) \ No newline at end of file diff --git a/FlagEmbedding/evaluation/msmarco/runner.py b/FlagEmbedding/evaluation/msmarco/runner.py new file mode 100644 index 0000000..2bfa891 --- /dev/null +++ b/FlagEmbedding/evaluation/msmarco/runner.py @@ -0,0 +1,14 @@ +from FlagEmbedding.abc.evaluation import AbsEvalRunner + +from .data_loader import MSMARCOEvalDataLoader + + +class MSMARCOEvalRunner(AbsEvalRunner): + def load_data_loader(self) -> MSMARCOEvalDataLoader: + data_loader = MSMARCOEvalDataLoader( + eval_name=self.eval_args.eval_name, + dataset_dir=self.eval_args.dataset_dir, + cache_dir=self.eval_args.cache_path, + token=self.eval_args.token + ) + return data_loader diff --git a/examples/evaluation/miracl/eval_miracl.sh b/examples/evaluation/miracl/eval_miracl.sh index 9b282b2..568e4c1 100644 --- a/examples/evaluation/miracl/eval_miracl.sh +++ b/examples/evaluation/miracl/eval_miracl.sh @@ -2,7 +2,9 @@ if [ -z "$HF_HUB_CACHE" ]; then export HF_HUB_CACHE="$HOME/.cache/huggingface/hub" fi -dataset_names="bn hi sw te th yo" +# dataset_names="bn hi sw te th yo" +dataset_names="sw" +HF_HUB_CACHE="/share/shared_models" eval_args="\ --eval_name 
diff --git a/examples/evaluation/miracl/eval_miracl.sh b/examples/evaluation/miracl/eval_miracl.sh
index 9b282b2..568e4c1 100644
--- a/examples/evaluation/miracl/eval_miracl.sh
+++ b/examples/evaluation/miracl/eval_miracl.sh
@@ -2,7 +2,9 @@ if [ -z "$HF_HUB_CACHE" ]; then
     export HF_HUB_CACHE="$HOME/.cache/huggingface/hub"
 fi
 
-dataset_names="bn hi sw te th yo"
+# dataset_names="bn hi sw te th yo"
+dataset_names="sw"
+HF_HUB_CACHE="/share/shared_models"
 
 eval_args="\
     --eval_name miracl \
@@ -23,7 +25,7 @@ eval_args="\
 model_args="\
     --embedder_name_or_path BAAI/bge-m3 \
     --reranker_name_or_path BAAI/bge-reranker-v2-m3 \
-    --devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7 \
+    --devices cuda:0 \
     --cache_dir $HF_HUB_CACHE \
     --reranker_max_length 1024 \
 "
diff --git a/examples/evaluation/msmarco/eval_msmarco.sh b/examples/evaluation/msmarco/eval_msmarco.sh
new file mode 100644
index 0000000..b9f521b
--- /dev/null
+++ b/examples/evaluation/msmarco/eval_msmarco.sh
@@ -0,0 +1,39 @@
+if [ -z "$HF_HUB_CACHE" ]; then
+    export HF_HUB_CACHE="$HOME/.cache/huggingface/hub"
+fi
+
+HF_HUB_CACHE="/share/shared_models"
+
+dataset_names="passage"
+
+eval_args="\
+    --eval_name msmarco \
+    --dataset_dir /share/chaofan/code/FlagEmbedding_update/data/msmarco \
+    --dataset_names $dataset_names \
+    --splits dl19 \
+    --corpus_embd_save_dir /share/chaofan/code/FlagEmbedding_update/data/msmarco/corpus_embd \
+    --output_dir /share/chaofan/code/FlagEmbedding_update/data/msmarco/search_results \
+    --search_top_k 1000 --rerank_top_k 100 \
+    --cache_path $HF_HUB_CACHE \
+    --overwrite False \
+    --k_values 10 100 \
+    --eval_output_method markdown \
+    --eval_output_path /share/chaofan/code/FlagEmbedding_update/data/msmarco/msmarco_eval_results.md \
+    --eval_metrics ndcg_at_10 recall_at_100 \
+"
+
+model_args="\
+    --embedder_name_or_path BAAI/bge-m3 \
+    --reranker_name_or_path BAAI/bge-reranker-v2-m3 \
+    --devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7 \
+    --cache_dir $HF_HUB_CACHE \
+    --reranker_max_length 1024 \
+"
+
+cmd="python -m FlagEmbedding.evaluation.msmarco \
+    $eval_args \
+    $model_args \
+"
+
+echo $cmd
+eval $cmd
diff --git a/setup.py b/setup.py
index 63b38dc..1e63975 100644
--- a/setup.py
+++ b/setup.py
@@ -23,6 +23,7 @@ setup(
         'beir',
         'deepspeed',
         'flash-attn==2.5.6',
-        'mteb==1.15.0'
+        'mteb==1.15.0',
+        'ir-datasets'
     ],
 )
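`ir-datasets` is added to the requirements because the `irds/msmarco-document` Hub loader used in the new data loader wraps that package; without it the document-corpus path fails at import time. A quick check (same call as in `_load_remote_corpus`):

    import datasets

    # Raises if the ir-datasets package is not installed.
    docs = datasets.load_dataset("irds/msmarco-document", "docs", trust_remote_code=True)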