mirror of https://github.com/FlagOpen/FlagEmbedding.git (synced 2025-06-27 02:39:58 +00:00)
fix eval reranker

parent 2d3ab29db9
commit 47e42fcbc6
@@ -133,7 +133,7 @@ class AbsEvalDataLoader(ABC):
         else:
             corpus_data = datasets.load_dataset('json', data_files=corpus_path, cache_dir=self.cache_dir)['train']
 
-            corpus = {e['id']: {'text': e['text']} for e in corpus_data}
+            corpus = {e['id']: {'title': e.get('title', ""), 'text': e['text']} for e in corpus_data}
             return datasets.DatasetDict(corpus)
 
     def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
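
For context, the updated comprehension keeps corpus entries that lack a title by falling back to an empty string; a minimal sketch of that behaviour (the sample entries are illustrative, not from any dataset):

    # Illustrative corpus entries; 'd2' has no title field.
    corpus_data = [
        {"id": "d1", "title": "MS MARCO", "text": "A passage ranking dataset."},
        {"id": "d2", "text": "An entry without a title."},
    ]
    corpus = {e['id']: {'title': e.get('title', ""), 'text': e['text']} for e in corpus_data}
    # corpus["d2"]["title"] == ""
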
@@ -99,9 +99,6 @@ class EvalDenseRetriever(EvalRetriever):
                 corpus_emb = np.load(os.path.join(corpus_embd_save_dir, "doc.npy"))
             else:
                 corpus_emb = self.embedder.encode_corpus(corpus_texts, **kwargs)
-                if corpus_embd_save_dir is not None:
-                    os.makedirs(corpus_embd_save_dir, exist_ok=True)
-                    np.save(os.path.join(corpus_embd_save_dir, "doc.npy"), corpus_emb)
         else:
             corpus_emb = self.embedder.encode_corpus(corpus_texts, **kwargs)
 
@@ -112,6 +109,10 @@ class EvalDenseRetriever(EvalRetriever):
             corpus_emb = corpus_emb["dense_vecs"]
         if isinstance(queries_emb, dict):
             queries_emb = queries_emb["dense_vecs"]
 
+        if corpus_embd_save_dir is not None and not os.path.exists(os.path.join(corpus_embd_save_dir, "doc.npy")):
+            os.makedirs(corpus_embd_save_dir, exist_ok=True)
+            np.save(os.path.join(corpus_embd_save_dir, "doc.npy"), corpus_emb)
+
         faiss_index = index(corpus_embeddings=corpus_emb)
         all_scores, all_indices = search(query_embeddings=queries_emb, faiss_index=faiss_index, k=self.search_top_k)
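
The change above stops saving doc.npy inside the encoding branch and instead writes it once, after the dense vectors have been extracted and only if no cached copy exists yet. A minimal sketch of that cache-if-absent pattern (the path and array shape are illustrative):

    import os
    import numpy as np

    def cache_corpus_embeddings(corpus_emb: np.ndarray, corpus_embd_save_dir: str) -> None:
        # Write the embeddings only when no cached file is present.
        save_path = os.path.join(corpus_embd_save_dir, "doc.npy")
        if not os.path.exists(save_path):
            os.makedirs(corpus_embd_save_dir, exist_ok=True)
            np.save(save_path, corpus_emb)

    cache_corpus_embeddings(np.zeros((3, 4), dtype=np.float32), "/tmp/corpus_embd")
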
@@ -174,8 +175,11 @@ class EvalReranker:
         )
         # generate sentence pairs
         sentence_pairs = []
+        pairs = []
         for qid in search_results:
             for docid in search_results[qid]:
+                # print(corpus[docid])
+                # sys.exit()
                 sentence_pairs.append(
                     {
                         "qid": qid,
@@ -185,9 +189,19 @@ class EvalReranker:
                         else f"{corpus[docid]['title']} {corpus[docid]['text']}".strip(),
                     }
                 )
-        pairs = [(e["query"], e["doc"]) for e in sentence_pairs]
+                pairs.append(
+                    (
+                        queries[qid],
+                        corpus[docid]["text"] if "title" not in corpus[docid]
+                        else f"{corpus[docid]['title']} {corpus[docid]['text']}".strip()
+                    )
+                )
+        # pairs = [(e["query"], e["doc"]) for e in sentence_pairs]
         # compute scores
         scores = self.reranker.compute_score(pairs)
+        # print(scores)
+        # print(self.reranker.compute_score([pairs[0]]))
+        # print(pairs[0], sentence_pairs[0])
         for i, score in enumerate(scores):
             sentence_pairs[i]["score"] = float(score)
         # rerank
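
With this change the reranker scores plain (query, passage) tuples built directly from the search results rather than the dict entries in sentence_pairs. A minimal standalone sketch of scoring such pairs, assuming the FlagAutoReranker defaults are sufficient (model name and texts are illustrative):

    from FlagEmbedding import FlagAutoReranker

    reranker = FlagAutoReranker.from_finetuned(
        model_name_or_path="BAAI/bge-reranker-v2-m3",
        use_fp16=True,
    )
    pairs = [
        ("what is the capital of france", "Paris is the capital and largest city of France."),
        ("what is the capital of france", "The Nile is the longest river in Africa."),
    ]
    scores = reranker.compute_score(pairs)  # one relevance score per (query, passage) pair
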
@@ -126,6 +126,8 @@ class AbsReranker(ABC):
         sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
         **kwargs
     ):
+        if isinstance(sentence_pairs[0], str):
+            sentence_pairs = [sentence_pairs]
         sentence_pairs = self.get_detailed_inputs(sentence_pairs)
 
         if isinstance(sentence_pairs, str) or len(self.target_devices) == 1:
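
The added guard lets compute_score accept a single (query, passage) pair as well as a list of pairs. A minimal sketch of the normalization step, assuming tuple or list inputs:

    from typing import List, Tuple, Union

    def normalize_pairs(
        sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]]
    ) -> List[Tuple[str, str]]:
        # A bare pair such as ("query", "passage") is wrapped into a one-element list.
        if isinstance(sentence_pairs[0], str):
            return [sentence_pairs]
        return list(sentence_pairs)

    assert normalize_pairs(("q", "p")) == [("q", "p")]
    assert normalize_pairs([("q1", "p1"), ("q2", "p2")]) == [("q1", "p1"), ("q2", "p2")]
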
@@ -1,13 +1,13 @@
 python __main__.py \
     --dataset_dir /share/chaofan/code/FlagEmbedding_update/data/BEIR \
     --embedder_name_or_path BAAI/bge-large-en-v1.5 \
-    --reranker_name_or_path BAAI/bge-reranker-large \
+    --reranker_name_or_path BAAI/bge-reranker-v2-m3 \
     --query_instruction_for_retrieval "Represent this sentence for searching relevant passages: " \
     --use_fp16 True \
     --devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7 \
     --cache_dir /share/shared_models \
     --corpus_embd_save_dir /share/chaofan/code/FlagEmbedding_update/data/BEIR_passage_embds \
-    --reranker_max_length 512 \
-    --dataset_names arguana \
-    --use_special_instructions True
+    --reranker_max_length 1024 \
+    --dataset_names trec-covid webis-touche2020 \
+    --use_special_instructions False
FlagEmbedding/evaluation/msmarco/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
from FlagEmbedding.abc.evaluation import (
    AbsEvalArgs as MSMARCOEvalArgs,
    AbsEvalModelArgs as MSMARCOEvalModelArgs,
)

from .data_loader import MSMARCOEvalDataLoader
from .runner import MSMARCOEvalRunner

__all__ = [
    "MSMARCOEvalArgs",
    "MSMARCOEvalModelArgs",
    "MSMARCOEvalRunner",
    "MSMARCOEvalDataLoader",
]
@@ -1,95 +1,23 @@
 from transformers import HfArgumentParser
 
-from FlagEmbedding import FlagAutoModel, FlagAutoReranker
-from FlagEmbedding.abc.evaluation import AbsModelArgs, AbsEmbedder, AbsReranker, AbsEvaluator
+from FlagEmbedding.evaluation.msmarco import (
+    MSMARCOEvalArgs, MSMARCOEvalModelArgs,
+    MSMARCOEvalRunner
+)
 
 
-from utils.arguments import MSMARCOEvalArgs
-from utils.data_loader import MSMARCODataLoader
+parser = HfArgumentParser((
+    MSMARCOEvalArgs,
+    MSMARCOEvalModelArgs
+))
 
+eval_args, model_args = parser.parse_args_into_dataclasses()
+eval_args: MSMARCOEvalArgs
+model_args: MSMARCOEvalModelArgs
 
-def get_models(model_args: AbsModelArgs):
-    retriever = FlagAutoModel.from_finetuned(
-        model_name_or_path=model_args.embedder_name_or_path,
-        normalize_embeddings=model_args.normalize_embeddings,
-        use_fp16=model_args.use_fp16,
-        query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
-        query_instruction_format=model_args.query_instruction_format_for_retrieval,
-        devices=model_args.devices,
-        examples_for_task=model_args.examples_for_task,
-        examples_instruction_format=model_args.examples_instruction_format,
-        trust_remote_code=model_args.trust_remote_code,
-        cache_dir=model_args.cache_dir,
-        batch_size=model_args.retriever_batch_size,
-        query_max_length=model_args.retriever_query_max_length,
-        passage_max_length=model_args.retriever_passage_max_length,
-    )
-    reranker = None
-    if model_args.reranker_name_or_path is not None:
-        reranker = FlagAutoReranker.from_finetuned(
-            model_name_or_path=model_args.reranker_name_or_path,
-            peft_path=model_args.reranker_peft_path,
-            use_fp16=model_args.use_fp16,
-            use_bf16=model_args.use_bf16,
-            query_instruction_for_rerank=model_args.query_instruction_for_rerank,
-            query_instruction_format=model_args.query_instruction_format_for_rerank,
-            passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
-            passage_instruction_format=model_args.passage_instruction_format_for_rerank,
-            cache_dir=model_args.cache_dir,
-            trust_remote_code=model_args.trust_remote_code,
-            devices=model_args.devices,
-            normalize=model_args.normalize,
-            prompt=model_args.prompt,
-            cutoff_layers=model_args.cutoff_layers,
-            compress_layers=model_args.compress_layers,
-            compress_ratio=model_args.compress_ratio,
-            batch_size=model_args.reranker_batch_size,
-            query_max_length=model_args.reranker_query_max_length,
-            max_length=model_args.reranker_max_length,
-        )
-    return retriever, reranker
+runner = MSMARCOEvalRunner(
+    eval_args=eval_args,
+    model_args=model_args
+)
 
 
-def main():
-    parser = HfArgumentParser([AbsModelArgs, MSMARCOEvalArgs])
-    model_args, eval_args = parser.parse_args_into_dataclasses()
-    model_args: AbsModelArgs
-    eval_args: MSMARCOEvalArgs
-
-    retriever, reranker = get_models(model_args)
-
-    data_loader = MSMARCODataLoader(
-        dataset_dir = eval_args.dataset_dir,
-        cache_dir = eval_args.cache_path,
-        text_type = eval_args.text_type
-    )
-
-    evaluation = AbsEvaluator(
-        data_loader=data_loader,
-        overwrite=eval_args.overwrite,
-    )
-
-    retriever = AbsEmbedder(
-        retriever,
-        search_top_k=eval_args.search_top_k,
-    )
-
-    if reranker is not None:
-        reranker = AbsReranker(
-            reranker,
-            rerank_top_k=eval_args.rerank_top_k,
-        )
-    else:
-        reranker = None
-
-    evaluation(
-        splits=eval_args.splits.split(),
-        search_results_save_dir=eval_args.output_dir,
-        retriever=retriever,
-        reranker=reranker,
-        corpus_embd_save_dir=eval_args.corpus_embd_save_dir,
-    )
-
-
-if __name__ == "__main__":
-    main()
+runner.run()
FlagEmbedding/evaluation/msmarco/__main__2.py (new file, 95 lines)
@@ -0,0 +1,95 @@
from transformers import HfArgumentParser

from FlagEmbedding import FlagAutoModel, FlagAutoReranker
from FlagEmbedding.abc.evaluation import AbsModelArgs, AbsEmbedder, AbsReranker, AbsEvaluator


from utils.arguments import MSMARCOEvalArgs
from utils.data_loader import MSMARCODataLoader


def get_models(model_args: AbsModelArgs):
    retriever = FlagAutoModel.from_finetuned(
        model_name_or_path=model_args.embedder_name_or_path,
        normalize_embeddings=model_args.normalize_embeddings,
        use_fp16=model_args.use_fp16,
        query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
        query_instruction_format=model_args.query_instruction_format_for_retrieval,
        devices=model_args.devices,
        examples_for_task=model_args.examples_for_task,
        examples_instruction_format=model_args.examples_instruction_format,
        trust_remote_code=model_args.trust_remote_code,
        cache_dir=model_args.cache_dir,
        batch_size=model_args.retriever_batch_size,
        query_max_length=model_args.retriever_query_max_length,
        passage_max_length=model_args.retriever_passage_max_length,
    )
    reranker = None
    if model_args.reranker_name_or_path is not None:
        reranker = FlagAutoReranker.from_finetuned(
            model_name_or_path=model_args.reranker_name_or_path,
            peft_path=model_args.reranker_peft_path,
            use_fp16=model_args.use_fp16,
            use_bf16=model_args.use_bf16,
            query_instruction_for_rerank=model_args.query_instruction_for_rerank,
            query_instruction_format=model_args.query_instruction_format_for_rerank,
            passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
            passage_instruction_format=model_args.passage_instruction_format_for_rerank,
            cache_dir=model_args.cache_dir,
            trust_remote_code=model_args.trust_remote_code,
            devices=model_args.devices,
            normalize=model_args.normalize,
            prompt=model_args.prompt,
            cutoff_layers=model_args.cutoff_layers,
            compress_layers=model_args.compress_layers,
            compress_ratio=model_args.compress_ratio,
            batch_size=model_args.reranker_batch_size,
            query_max_length=model_args.reranker_query_max_length,
            max_length=model_args.reranker_max_length,
        )
    return retriever, reranker


def main():
    parser = HfArgumentParser([AbsModelArgs, MSMARCOEvalArgs])
    model_args, eval_args = parser.parse_args_into_dataclasses()
    model_args: AbsModelArgs
    eval_args: MSMARCOEvalArgs

    retriever, reranker = get_models(model_args)

    data_loader = MSMARCODataLoader(
        dataset_dir = eval_args.dataset_dir,
        cache_dir = eval_args.cache_path,
        text_type = eval_args.text_type
    )

    evaluation = AbsEvaluator(
        data_loader=data_loader,
        overwrite=eval_args.overwrite,
    )

    retriever = AbsEmbedder(
        retriever,
        search_top_k=eval_args.search_top_k,
    )

    if reranker is not None:
        reranker = AbsReranker(
            reranker,
            rerank_top_k=eval_args.rerank_top_k,
        )
    else:
        reranker = None

    evaluation(
        splits=eval_args.splits.split(),
        search_results_save_dir=eval_args.output_dir,
        retriever=retriever,
        reranker=reranker,
        corpus_embd_save_dir=eval_args.corpus_embd_save_dir,
    )


if __name__ == "__main__":
    main()
FlagEmbedding/evaluation/msmarco/data_loader.py (new file, 235 lines)
@@ -0,0 +1,235 @@
import os
import json
import logging
import datasets
from tqdm import tqdm
from typing import List, Optional

from FlagEmbedding.abc.evaluation import AbsEvalDataLoader

logger = logging.getLogger(__name__)


class MSMARCOEvalDataLoader(AbsEvalDataLoader):
    def available_dataset_names(self) -> List[str]:
        return ["passage", "document"]

    def available_splits(self, dataset_name: str = None) -> List[str]:
        return ["dev", "dl19", "dl20"]

    def _load_remote_corpus(
        self,
        dataset_name: str,
        save_dir: Optional[str] = None
    ) -> datasets.DatasetDict:
        if dataset_name == 'passage':
            corpus = datasets.load_dataset(
                'Tevatron/msmarco-passage-corpus',
                'default',
                trust_remote_code=True,
                cache_dir=self.cache_dir,
            )['train']
        else:
            corpus = datasets.load_dataset(
                'irds/msmarco-document',
                'docs',
                trust_remote_code=True,
                cache_dir=self.cache_dir,
            )

        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, "corpus.jsonl")
            corpus_dict = {}
            with open(save_path, "w", encoding="utf-8") as f:
                for data in tqdm(corpus, desc="Loading and Saving corpus"):
                    _data = {
                        "id": data["docid"],
                        "title": data["title"],
                        "text": data.get("text", data.get("body", ""))
                    }
                    corpus_dict[data["docid"]] = {
                        "title": data["title"],
                        "text": data.get("text", data.get("body", ""))
                    }
                    f.write(json.dumps(_data, ensure_ascii=False) + "\n")
            logging.info(f"{self.eval_name} {dataset_name} corpus saved to {save_path}")
        else:
            corpus_dict = {data["docid"]: {"title": data["title"], "text": data.get("text", data.get("body", ""))} for data in tqdm(corpus, desc="Loading corpus")}
        return datasets.DatasetDict(corpus_dict)

    def _download_gz_file(self, download_url: str, save_dir: str):
        cmd = f"wget -P {save_dir} {download_url}"
        os.system(cmd)
        save_path = os.path.join(save_dir, download_url.split('/')[-1])

        if not os.path.exists(save_path) or os.path.getsize(save_path) == 0:
            raise FileNotFoundError(f"Failed to download file from {download_url} to {save_path}")
        else:
            logger.info(f"Downloaded file from {download_url} to {save_path}")
            cmd = f"gzip -d {save_path}"
            os.system(cmd)
            new_save_path = save_path.replace(".gz", "")
            logger.info(f"Unzip file from {save_path} to {new_save_path}")
            return new_save_path

    def _download_file(self, download_url: str, save_dir: str):
        cmd = f"wget -P {save_dir} {download_url}"
        os.system(cmd)
        save_path = os.path.join(save_dir, download_url.split('/')[-1])

        if not os.path.exists(save_path) or os.path.getsize(save_path) == 0:
            raise FileNotFoundError(f"Failed to download file from {download_url} to {save_path}")
        else:
            logger.info(f"Downloaded file from {download_url} to {save_path}")
            return save_path

    def _load_remote_qrels(
        self,
        dataset_name: Optional[str] = None,
        split: str = 'dev',
        save_dir: Optional[str] = None
    ) -> datasets.DatasetDict:
        if dataset_name == 'passage':
            if split == 'dev':
                qrels = datasets.load_dataset(
                    'BeIR/msmarco-qrels',
                    split='validation',
                    trust_remote_code=True,
                    cache_dir=self.cache_dir,
                )
                qrels_download_url = None
            elif split == 'dl19':
                qrels_download_url = "https://trec.nist.gov/data/deep/2019qrels-pass.txt"
            else:
                qrels_download_url = "https://trec.nist.gov/data/deep/2020qrels-pass.txt"
        else:
            if split == 'dev':
                qrels_download_url = "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docdev-qrels.tsv.gz"
            elif split == 'dl19':
                qrels_download_url = "https://trec.nist.gov/data/deep/2019qrels-docs.txt"
            else:
                qrels_download_url = "https://trec.nist.gov/data/deep/2020qrels-docs.txt"

        if qrels_download_url is not None:
            qrels_save_path = self._download_file(qrels_download_url, self.cache_dir)
        else:
            qrels_save_path = None

        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, f"{split}_qrels.jsonl")
            qrels_dict = {}
            if qrels_save_path is not None:
                with open(save_path, "w", encoding="utf-8") as f1:
                    with open(qrels_save_path, "r", encoding="utf-8") as f2:
                        for line in tqdm(f2.readlines(), desc="Loading and Saving qrels"):
                            qid, _, docid, rel = line.strip().split()
                            qid, docid, rel = str(qid), str(docid), int(rel)
                            _data = {
                                "qid": qid,
                                "docid": docid,
                                "relevance": rel
                            }
                            if qid not in qrels_dict:
                                qrels_dict[qid] = {}
                            qrels_dict[qid][docid] = rel
                            f1.write(json.dumps(_data, ensure_ascii=False) + "\n")
            else:
                with open(save_path, "w", encoding="utf-8") as f:
                    for data in tqdm(qrels, desc="Loading and Saving qrels"):
                        qid, docid, rel = str(data['query-id']), str(data['corpus-id']), int(data['score'])
                        _data = {
                            "qid": qid,
                            "docid": docid,
                            "relevance": rel
                        }
                        if qid not in qrels_dict:
                            qrels_dict[qid] = {}
                        qrels_dict[qid][docid] = rel
                        f.write(json.dumps(_data, ensure_ascii=False) + "\n")
            logging.info(f"{self.eval_name} {dataset_name} qrels saved to {save_path}")
        else:
            qrels_dict = {}
            if qrels_save_path is not None:
                with open(qrels_save_path, "r", encoding="utf-8") as f:
                    for line in tqdm(f.readlines(), desc="Loading qrels"):
                        qid, _, docid, rel = line.strip().split()
                        qid, docid, rel = str(qid), str(docid), int(rel)
                        if qid not in qrels_dict:
                            qrels_dict[qid] = {}
                        qrels_dict[qid][docid] = rel
            else:
                for data in tqdm(qrels, desc="Loading qrels"):
                    qid, docid, rel = str(data['query-id']), str(data['corpus-id']), int(data['score'])
                    if qid not in qrels_dict:
                        qrels_dict[qid] = {}
                    qrels_dict[qid][docid] = rel
        return datasets.DatasetDict(qrels_dict)
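
The qrels files downloaded above use the standard TREC format: whitespace-separated lines of the form "qid iteration docid relevance". A minimal parsing sketch (the sample lines are illustrative):

    from collections import defaultdict

    def parse_trec_qrels(lines):
        # Each line: "<qid> <iteration> <docid> <relevance>".
        qrels = defaultdict(dict)
        for line in lines:
            qid, _, docid, rel = line.strip().split()
            qrels[str(qid)][str(docid)] = int(rel)
        return dict(qrels)

    print(parse_trec_qrels(["19335 0 1017759 0", "19335 0 1082489 2"]))
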

    def _load_remote_queries(
        self,
        dataset_name: Optional[str] = None,
        split: str = 'test',
        save_dir: Optional[str] = None
    ) -> datasets.DatasetDict:
        if split == 'dev':
            if dataset_name == 'passage':
                queries = datasets.load_dataset(
                    'BeIR/msmarco',
                    'queries',
                    trust_remote_code=True,
                    cache_dir=self.cache_dir,
                )['queries']
                queries_save_path = None
            else:
                queries_download_url = "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docdev-queries.tsv.gz"
                queries_save_path = self._download_gz_file(queries_download_url, self.cache_dir)
        else:
            year = split.replace("dl", "")
            queries_download_url = f"https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test20{year}-queries.tsv.gz"
            queries_save_path = self._download_gz_file(queries_download_url, self.cache_dir)

        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, f"{split}_queries.jsonl")
            queries_dict = {}
            if queries_save_path is not None:
                with open(save_path, "w", encoding="utf-8") as f1:
                    with open(queries_save_path, "r", encoding="utf-8") as f2:
                        for line in tqdm(f2.readlines(), desc="Loading and Saving queries"):
                            qid, query = line.strip().split("\t")
                            qid = str(qid)
                            _data = {
                                "id": qid,
                                "text": query
                            }
                            queries_dict[qid] = query
                            f1.write(json.dumps(_data, ensure_ascii=False) + "\n")
            else:
                with open(save_path, "w", encoding="utf-8") as f:
                    for data in tqdm(queries, desc="Loading and Saving queries"):
                        qid, query = data['_id'], data['text']
                        _data = {
                            "id": qid,
                            "text": query
                        }
                        queries_dict[qid] = query
                        f.write(json.dumps(_data, ensure_ascii=False) + "\n")
            logging.info(f"{self.eval_name} {dataset_name} queries saved to {save_path}")
        else:
            queries_dict = {}
            if queries_save_path is not None:
                with open(queries_save_path, "r", encoding="utf-8") as f:
                    for line in tqdm(f.readlines(), desc="Loading queries"):
                        qid, query = line.strip().split("\t")
                        qid = str(qid)
                        queries_dict[qid] = query
            else:
                for data in tqdm(queries, desc="Loading queries"):
                    qid, query = data['_id'], data['text']
                    queries_dict[qid] = query
        return datasets.DatasetDict(queries_dict)
FlagEmbedding/evaluation/msmarco/runner.py (new file, 14 lines)
@@ -0,0 +1,14 @@
from FlagEmbedding.abc.evaluation import AbsEvalRunner

from .data_loader import MSMARCOEvalDataLoader


class MSMARCOEvalRunner(AbsEvalRunner):
    def load_data_loader(self) -> MSMARCOEvalDataLoader:
        data_loader = MSMARCOEvalDataLoader(
            eval_name=self.eval_args.eval_name,
            dataset_dir=self.eval_args.dataset_dir,
            cache_dir=self.eval_args.cache_path,
            token=self.eval_args.token
        )
        return data_loader
@@ -2,7 +2,9 @@ if [ -z "$HF_HUB_CACHE" ]; then
     export HF_HUB_CACHE="$HOME/.cache/huggingface/hub"
 fi
 
-dataset_names="bn hi sw te th yo"
+# dataset_names="bn hi sw te th yo"
+dataset_names="sw"
+HF_HUB_CACHE="/share/shared_models"
 
 eval_args="\
     --eval_name miracl \
@@ -23,7 +25,7 @@ eval_args="\
 model_args="\
     --embedder_name_or_path BAAI/bge-m3 \
     --reranker_name_or_path BAAI/bge-reranker-v2-m3 \
-    --devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7 \
+    --devices cuda:0 \
     --cache_dir $HF_HUB_CACHE \
     --reranker_max_length 1024 \
 "
examples/evaluation/msmarco/eval_msmarco.sh (new file, 39 lines)
@@ -0,0 +1,39 @@
if [ -z "$HF_HUB_CACHE" ]; then
    export HF_HUB_CACHE="$HOME/.cache/huggingface/hub"
fi

HF_HUB_CACHE="/share/shared_models"

dataset_names="passage"

eval_args="\
    --eval_name msmarco \
    --dataset_dir /share/chaofan/code/FlagEmbedding_update/data/msmarco \
    --dataset_names $dataset_names \
    --splits dl19 \
    --corpus_embd_save_dir /share/chaofan/code/FlagEmbedding_update/data/msmarco/corpus_embd \
    --output_dir /share/chaofan/code/FlagEmbedding_update/data/msmarco/search_results \
    --search_top_k 1000 --rerank_top_k 100 \
    --cache_path $HF_HUB_CACHE \
    --overwrite False \
    --k_values 10 100 \
    --eval_output_method markdown \
    --eval_output_path /share/chaofan/code/FlagEmbedding_update/data/msmarco/msmarco_eval_results.md \
    --eval_metrics ndcg_at_10 recall_at_100 \
"

model_args="\
    --embedder_name_or_path BAAI/bge-m3 \
    --reranker_name_or_path BAAI/bge-reranker-v2-m3 \
    --devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7 \
    --cache_dir $HF_HUB_CACHE \
    --reranker_max_length 1024 \
"

cmd="python -m FlagEmbedding.evaluation.msmarco \
    $eval_args \
    $model_args \
"

echo $cmd
eval $cmd