fix eval reranker

cfli 2024-10-27 02:06:12 +08:00
parent 2d3ab29db9
commit 47e42fcbc6
12 changed files with 444 additions and 100 deletions

View File

@@ -133,7 +133,7 @@ class AbsEvalDataLoader(ABC):
else:
corpus_data = datasets.load_dataset('json', data_files=corpus_path, cache_dir=self.cache_dir)['train']
corpus = {e['id']: {'text': e['text']} for e in corpus_data}
corpus = {e['id']: {'title': e.get('title', ""), 'text': e['text']} for e in corpus_data}
return datasets.DatasetDict(corpus)
def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
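For reference, a minimal sketch of the corpus mapping the fixed comprehension produces; the sample records below are hypothetical:

corpus_data = [
    {"id": "d1", "title": "Pandas", "text": "The giant panda is a bear species endemic to China."},
    {"id": "d2", "text": "A corpus entry without a title."},
]
corpus = {e["id"]: {"title": e.get("title", ""), "text": e["text"]} for e in corpus_data}
# corpus["d2"] == {"title": "", "text": "A corpus entry without a title."}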

View File

@@ -99,9 +99,6 @@ class EvalDenseRetriever(EvalRetriever):
corpus_emb = np.load(os.path.join(corpus_embd_save_dir, "doc.npy"))
else:
corpus_emb = self.embedder.encode_corpus(corpus_texts, **kwargs)
if corpus_embd_save_dir is not None:
os.makedirs(corpus_embd_save_dir, exist_ok=True)
np.save(os.path.join(corpus_embd_save_dir, "doc.npy"), corpus_emb)
else:
corpus_emb = self.embedder.encode_corpus(corpus_texts, **kwargs)
@@ -112,6 +109,10 @@ class EvalDenseRetriever(EvalRetriever):
corpus_emb = corpus_emb["dense_vecs"]
if isinstance(queries_emb, dict):
queries_emb = queries_emb["dense_vecs"]
if corpus_embd_save_dir is not None and not os.path.exists(os.path.join(corpus_embd_save_dir, "doc.npy")):
os.makedirs(corpus_embd_save_dir, exist_ok=True)
np.save(os.path.join(corpus_embd_save_dir, "doc.npy"), corpus_emb)
faiss_index = index(corpus_embeddings=corpus_emb)
all_scores, all_indices = search(query_embeddings=queries_emb, faiss_index=faiss_index, k=self.search_top_k)
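The two hunks above defer the doc.npy cache write until after encoding and after unwrapping the dense_vecs entry, so a plain array is saved rather than a dict. A minimal sketch of the resulting cache-then-search pattern; encode_fn stands in for self.embedder.encode_corpus and the paths are illustrative:

import os
import numpy as np

def encode_with_cache(encode_fn, corpus_texts, cache_dir=None):
    cache_file = os.path.join(cache_dir, "doc.npy") if cache_dir else None
    if cache_file is not None and os.path.exists(cache_file):
        return np.load(cache_file)          # reuse previously saved embeddings
    emb = encode_fn(corpus_texts)
    if isinstance(emb, dict):               # some embedders return {"dense_vecs": ...}
        emb = emb["dense_vecs"]
    if cache_file is not None:
        os.makedirs(cache_dir, exist_ok=True)
        np.save(cache_file, emb)            # save the plain ndarray, not the dict
    return emb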
@@ -174,8 +175,11 @@ class EvalReranker:
)
# generate sentence pairs
sentence_pairs = []
pairs = []
for qid in search_results:
for docid in search_results[qid]:
sentence_pairs.append(
{
"qid": qid,
@@ -185,9 +189,19 @@
else f"{corpus[docid]['title']} {corpus[docid]['text']}".strip(),
}
)
pairs = [(e["query"], e["doc"]) for e in sentence_pairs]
pairs.append(
(
queries[qid],
corpus[docid]["text"] if "title" not in corpus[docid]
else f"{corpus[docid]['title']} {corpus[docid]['text']}".strip()
)
)
# compute scores
scores = self.reranker.compute_score(pairs)
for i, score in enumerate(scores):
sentence_pairs[i]["score"] = float(score)
# rerank
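For context, the pairs assembled above feed compute_score; a minimal standalone sketch using the public FlagReranker API (model name and strings are illustrative):

from FlagEmbedding import FlagReranker

reranker = FlagReranker("BAAI/bge-reranker-v2-m3", use_fp16=True)
pairs = [
    ("what is panda?", "The giant panda is a bear species endemic to China."),
    ("what is panda?", "Paris is the capital of France."),
]
scores = reranker.compute_score(pairs)  # one relevance score per (query, passage) pair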

View File

@@ -126,6 +126,8 @@ class AbsReranker(ABC):
sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
**kwargs
):
if isinstance(sentence_pairs[0], str):
sentence_pairs = [sentence_pairs]
sentence_pairs = self.get_detailed_inputs(sentence_pairs)
if isinstance(sentence_pairs, str) or len(self.target_devices) == 1:
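The added guard means a bare (query, passage) tuple is accepted as well as a list of tuples; a hypothetical call reusing the reranker from the sketch above:

pair = ("what is panda?", "The giant panda is a bear species endemic to China.")
reranker.compute_score(pair)    # wrapped internally as [pair]
reranker.compute_score([pair])  # equivalent explicit form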

View File

@@ -1,13 +1,13 @@
python __main__.py \
--dataset_dir /share/chaofan/code/FlagEmbedding_update/data/BEIR \
--embedder_name_or_path BAAI/bge-large-en-v1.5 \
--reranker_name_or_path BAAI/bge-reranker-large \
--reranker_name_or_path BAAI/bge-reranker-v2-m3 \
--query_instruction_for_retrieval "Represent this sentence for searching relevant passages: " \
--use_fp16 True \
--devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7 \
--cache_dir /share/shared_models \
--corpus_embd_save_dir /share/chaofan/code/FlagEmbedding_update/data/BEIR_passage_embds \
--reranker_max_length 512 \
--dataset_names arguana \
--use_special_instructions True
--reranker_max_length 1024 \
--dataset_names trec-covid webis-touche2020 \
--use_special_instructions False

View File

@@ -0,0 +1,14 @@
from FlagEmbedding.abc.evaluation import (
AbsEvalArgs as MSMARCOEvalArgs,
AbsEvalModelArgs as MSMARCOEvalModelArgs,
)
from .data_loader import MSMARCOEvalDataLoader
from .runner import MSMARCOEvalRunner
__all__ = [
"MSMARCOEvalArgs",
"MSMARCOEvalModelArgs",
"MSMARCOEvalRunner",
"MSMARCOEvalDataLoader",
]

View File

@@ -1,95 +1,23 @@
from transformers import HfArgumentParser
from FlagEmbedding import FlagAutoModel, FlagAutoReranker
from FlagEmbedding.abc.evaluation import AbsModelArgs, AbsEmbedder, AbsReranker, AbsEvaluator
from FlagEmbedding.evaluation.msmarco import (
MSMARCOEvalArgs, MSMARCOEvalModelArgs,
MSMARCOEvalRunner
)
from utils.arguments import MSMARCOEvalArgs
from utils.data_loader import MSMARCODataLoader
parser = HfArgumentParser((
MSMARCOEvalArgs,
MSMARCOEvalModelArgs
))
eval_args, model_args = parser.parse_args_into_dataclasses()
eval_args: MSMARCOEvalArgs
model_args: MSMARCOEvalModelArgs
def get_models(model_args: AbsModelArgs):
retriever = FlagAutoModel.from_finetuned(
model_name_or_path=model_args.embedder_name_or_path,
normalize_embeddings=model_args.normalize_embeddings,
use_fp16=model_args.use_fp16,
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
query_instruction_format=model_args.query_instruction_format_for_retrieval,
devices=model_args.devices,
examples_for_task=model_args.examples_for_task,
examples_instruction_format=model_args.examples_instruction_format,
trust_remote_code=model_args.trust_remote_code,
cache_dir=model_args.cache_dir,
batch_size=model_args.retriever_batch_size,
query_max_length=model_args.retriever_query_max_length,
passage_max_length=model_args.retriever_passage_max_length,
)
reranker = None
if model_args.reranker_name_or_path is not None:
reranker = FlagAutoReranker.from_finetuned(
model_name_or_path=model_args.reranker_name_or_path,
peft_path=model_args.reranker_peft_path,
use_fp16=model_args.use_fp16,
use_bf16=model_args.use_bf16,
query_instruction_for_rerank=model_args.query_instruction_for_rerank,
query_instruction_format=model_args.query_instruction_format_for_rerank,
passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
passage_instruction_format=model_args.passage_instruction_format_for_rerank,
cache_dir=model_args.cache_dir,
trust_remote_code=model_args.trust_remote_code,
devices=model_args.devices,
normalize=model_args.normalize,
prompt=model_args.prompt,
cutoff_layers=model_args.cutoff_layers,
compress_layers=model_args.compress_layers,
compress_ratio=model_args.compress_ratio,
batch_size=model_args.reranker_batch_size,
query_max_length=model_args.reranker_query_max_length,
max_length=model_args.reranker_max_length,
)
return retriever, reranker
runner = MSMARCOEvalRunner(
eval_args=eval_args,
model_args=model_args
)
def main():
parser = HfArgumentParser([AbsModelArgs, MSMARCOEvalArgs])
model_args, eval_args = parser.parse_args_into_dataclasses()
model_args: AbsModelArgs
eval_args: MSMARCOEvalArgs
retriever, reranker = get_models(model_args)
data_loader = MSMARCODataLoader(
dataset_dir = eval_args.dataset_dir,
cache_dir = eval_args.cache_path,
text_type = eval_args.text_type
)
evaluation = AbsEvaluator(
data_loader=data_loader,
overwrite=eval_args.overwrite,
)
retriever = AbsEmbedder(
retriever,
search_top_k=eval_args.search_top_k,
)
if reranker is not None:
reranker = AbsReranker(
reranker,
rerank_top_k=eval_args.rerank_top_k,
)
else:
reranker = None
evaluation(
splits=eval_args.splits.split(),
search_results_save_dir=eval_args.output_dir,
retriever=retriever,
reranker=reranker,
corpus_embd_save_dir=eval_args.corpus_embd_save_dir,
)
if __name__ == "__main__":
main()
runner.run()

View File

@@ -0,0 +1,95 @@
from transformers import HfArgumentParser
from FlagEmbedding import FlagAutoModel, FlagAutoReranker
from FlagEmbedding.abc.evaluation import AbsModelArgs, AbsEmbedder, AbsReranker, AbsEvaluator
from utils.arguments import MSMARCOEvalArgs
from utils.data_loader import MSMARCODataLoader
def get_models(model_args: AbsModelArgs):
retriever = FlagAutoModel.from_finetuned(
model_name_or_path=model_args.embedder_name_or_path,
normalize_embeddings=model_args.normalize_embeddings,
use_fp16=model_args.use_fp16,
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
query_instruction_format=model_args.query_instruction_format_for_retrieval,
devices=model_args.devices,
examples_for_task=model_args.examples_for_task,
examples_instruction_format=model_args.examples_instruction_format,
trust_remote_code=model_args.trust_remote_code,
cache_dir=model_args.cache_dir,
batch_size=model_args.retriever_batch_size,
query_max_length=model_args.retriever_query_max_length,
passage_max_length=model_args.retriever_passage_max_length,
)
reranker = None
if model_args.reranker_name_or_path is not None:
reranker = FlagAutoReranker.from_finetuned(
model_name_or_path=model_args.reranker_name_or_path,
peft_path=model_args.reranker_peft_path,
use_fp16=model_args.use_fp16,
use_bf16=model_args.use_bf16,
query_instruction_for_rerank=model_args.query_instruction_for_rerank,
query_instruction_format=model_args.query_instruction_format_for_rerank,
passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
passage_instruction_format=model_args.passage_instruction_format_for_rerank,
cache_dir=model_args.cache_dir,
trust_remote_code=model_args.trust_remote_code,
devices=model_args.devices,
normalize=model_args.normalize,
prompt=model_args.prompt,
cutoff_layers=model_args.cutoff_layers,
compress_layers=model_args.compress_layers,
compress_ratio=model_args.compress_ratio,
batch_size=model_args.reranker_batch_size,
query_max_length=model_args.reranker_query_max_length,
max_length=model_args.reranker_max_length,
)
return retriever, reranker
def main():
parser = HfArgumentParser([AbsModelArgs, MSMARCOEvalArgs])
model_args, eval_args = parser.parse_args_into_dataclasses()
model_args: AbsModelArgs
eval_args: MSMARCOEvalArgs
retriever, reranker = get_models(model_args)
data_loader = MSMARCODataLoader(
dataset_dir = eval_args.dataset_dir,
cache_dir = eval_args.cache_path,
text_type = eval_args.text_type
)
evaluation = AbsEvaluator(
data_loader=data_loader,
overwrite=eval_args.overwrite,
)
retriever = AbsEmbedder(
retriever,
search_top_k=eval_args.search_top_k,
)
if reranker is not None:
reranker = AbsReranker(
reranker,
rerank_top_k=eval_args.rerank_top_k,
)
else:
reranker = None
evaluation(
splits=eval_args.splits.split(),
search_results_save_dir=eval_args.output_dir,
retriever=retriever,
reranker=reranker,
corpus_embd_save_dir=eval_args.corpus_embd_save_dir,
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,235 @@
import os
import json
import logging
import datasets
from tqdm import tqdm
from typing import List, Optional
from FlagEmbedding.abc.evaluation import AbsEvalDataLoader
logger = logging.getLogger(__name__)
class MSMARCOEvalDataLoader(AbsEvalDataLoader):
def available_dataset_names(self) -> List[str]:
return ["passage", "document"]
def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
return ["dev", "dl19", "dl20"]
def _load_remote_corpus(
self,
dataset_name: str,
save_dir: Optional[str] = None
) -> datasets.DatasetDict:
if dataset_name == 'passage':
corpus = datasets.load_dataset(
'Tevatron/msmarco-passage-corpus',
'default',
trust_remote_code=True,
cache_dir=self.cache_dir,
)['train']
else:
corpus = datasets.load_dataset(
'irds/msmarco-document',
'docs',
trust_remote_code=True,
cache_dir=self.cache_dir,
)
if save_dir is not None:
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "corpus.jsonl")
corpus_dict = {}
with open(save_path, "w", encoding="utf-8") as f:
for data in tqdm(corpus, desc="Loading and Saving corpus"):
_data = {
"id": data["docid"],
"title": data["title"],
"text": data.get("text", data.get("body", ""))
}
corpus_dict[data["docid"]] = {
"title": data["title"],
"text": data.get("text", data.get("body", ""))
}
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
logging.info(f"{self.eval_name} {dataset_name} corpus saved to {save_path}")
else:
corpus_dict = {data["docid"]: {"title": data["title"], "text": data.get("text", data.get("body", ""))} for data in tqdm(corpus, desc="Loading corpus")}
return datasets.DatasetDict(corpus_dict)
def _download_gz_file(self, download_url: str, save_dir: str):
cmd = f"wget -P {save_dir} {download_url}"
os.system(cmd)
save_path = os.path.join(save_dir, download_url.split('/')[-1])
if not os.path.exists(save_path) or os.path.getsize(save_path) == 0:
raise FileNotFoundError(f"Failed to download file from {download_url} to {save_path}")
else:
logger.info(f"Downloaded file from {download_url} to {save_path}")
cmd = f"gzip -d {save_path}"
os.system(cmd)
new_save_path = save_path.replace(".gz", "")
logger.info(f"Unzip file from {save_path} to {new_save_path}")
return new_save_path
def _download_file(self, download_url: str, save_dir: str):
cmd = f"wget -P {save_dir} {download_url}"
os.system(cmd)
save_path = os.path.join(save_dir, download_url.split('/')[-1])
if not os.path.exists(save_path) or os.path.getsize(save_path) == 0:
raise FileNotFoundError(f"Failed to download file from {download_url} to {save_path}")
else:
    logger.info(f"Downloaded file from {download_url} to {save_path}")
return save_path
def _load_remote_qrels(
self,
dataset_name: Optional[str] = None,
split: str = 'dev',
save_dir: Optional[str] = None
) -> datasets.DatasetDict:
if dataset_name == 'passage':
if split == 'dev':
qrels = datasets.load_dataset(
'BeIR/msmarco-qrels',
split='validation',
trust_remote_code=True,
cache_dir=self.cache_dir,
)
qrels_download_url = None
elif split == 'dl19':
qrels_download_url = "https://trec.nist.gov/data/deep/2019qrels-pass.txt"
else:
qrels_download_url = "https://trec.nist.gov/data/deep/2020qrels-pass.txt"
else:
if split == 'dev':
qrels_download_url = "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docdev-qrels.tsv.gz"
elif split == 'dl19':
qrels_download_url = "https://trec.nist.gov/data/deep/2019qrels-docs.txt"
else:
qrels_download_url = "https://trec.nist.gov/data/deep/2020qrels-docs.txt"
if qrels_download_url is not None:
    if qrels_download_url.endswith(".gz"):
        qrels_save_path = self._download_gz_file(qrels_download_url, self.cache_dir)
    else:
        qrels_save_path = self._download_file(qrels_download_url, self.cache_dir)
else:
qrels_save_path = None
if save_dir is not None:
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, f"{split}_qrels.jsonl")
qrels_dict = {}
if qrels_save_path is not None:
with open(save_path, "w", encoding="utf-8") as f1:
with open(qrels_save_path, "r", encoding="utf-8") as f2:
for line in tqdm(f2.readlines(), desc="Loading and Saving qrels"):
qid, _, docid, rel = line.strip().split()
qid, docid, rel = str(qid), str(docid), int(rel)
_data = {
"qid": qid,
"docid": docid,
"relevance": rel
}
if qid not in qrels_dict:
qrels_dict[qid] = {}
qrels_dict[qid][docid] = rel
f1.write(json.dumps(_data, ensure_ascii=False) + "\n")
else:
with open(save_path, "w", encoding="utf-8") as f:
for data in tqdm(queries, desc="Loading and Saving qrels"):
qid, docid, rel = str(data['query-id']), str(data['corpus-id']), int(data['score'])
_data = {
"id": qid,
"text": query
}
_data = {
"qid": qid,
"docid": docid,
"relevance": rel
}
if qid not in qrels_dict:
qrels_dict[qid] = {}
qrels_dict[qid][docid] = rel
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
logging.info(f"{self.eval_name} {dataset_name} qrels saved to {save_path}")
else:
qrels_dict = {}
if qrels_save_path is not None:
    with open(qrels_save_path, "r", encoding="utf-8") as f:
        for line in tqdm(f.readlines(), desc="Loading qrels"):
            qid, _, docid, rel = line.strip().split()
            qid, docid, rel = str(qid), str(docid), int(rel)
if qid not in qrels_dict:
qrels_dict[qid] = {}
qrels_dict[qid][docid] = rel
else:
for data in tqdm(queries, desc="Loading queries"):
qid, docid, rel = str(qid), str(docid), int(rel)
if qid not in qrels_dict:
qrels_dict[qid] = {}
qrels_dict[qid][docid] = rel
return datasets.DatasetDict(qrels_dict)
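The returned object follows the usual qrels shape, mapping query ids to per-document relevance judgments; a hypothetical example:

qrels_dict = {
    "1185869": {"D59219": 1, "D59221": 0},  # qid -> {docid: relevance}
    "1185868": {"D59235": 1},
}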
def _load_remote_queries(
self,
dataset_name: Optional[str] = None,
split: str = 'test',
save_dir: Optional[str] = None
) -> datasets.DatasetDict:
if split == 'dev':
if dataset_name == 'passage':
queries = datasets.load_dataset(
'BeIR/msmarco',
'queries',
trust_remote_code=True,
cache_dir=self.cache_dir,
)['queries']
queries_save_path = None
else:
queries_download_url = "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docdev-qrels.tsv.gz"
queries_save_path = self._download_gz_file(queries_download_url, self.cache_dir)
else:
year = split.replace("dl", "")
queries_download_url = f"https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test20{year}-queries.tsv.gz"
queries_save_path = self._download_gz_file(queries_download_url, self.cache_dir)
if save_dir is not None:
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, f"{split}_queries.jsonl")
queries_dict = {}
if queries_save_path is not None:
with open(save_path, "w", encoding="utf-8") as f1:
with open(queries_save_path, "r", encoding="utf-8") as f2:
for line in tqdm(f2.readlines(), desc="Loading and Saving queries"):
qid, query = line.strip().split("\t")
qid = str(qid)
_data = {
"id": qid,
"text": query
}
queries_dict[qid] = query
f1.write(json.dumps(_data, ensure_ascii=False) + "\n")
else:
with open(save_path, "w", encoding="utf-8") as f:
for data in tqdm(queries, desc="Loading and Saving queries"):
qid, query = data['_id'], data['text']
_data = {
"id": qid,
"text": query
}
queries_dict[qid] = query
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
logging.info(f"{self.eval_name} {dataset_name} queries saved to {save_path}")
else:
queries_dict = {}
if queries_save_path is not None:
with open(queries_save_path, "r", encoding="utf-8") as f:
for line in tqdm(f.readlines(), desc="Loading queries"):
qid, query = line.strip().split("\t")
qid = str(qid)
queries_dict[qid] = query
else:
for data in tqdm(queries, desc="Loading queries"):
qid, query = data['_id'], data['text']
queries_dict[qid] = query
return datasets.DatasetDict(queries_dict)

View File

@@ -0,0 +1,14 @@
from FlagEmbedding.abc.evaluation import AbsEvalRunner
from .data_loader import MSMARCOEvalDataLoader
class MSMARCOEvalRunner(AbsEvalRunner):
def load_data_loader(self) -> MSMARCOEvalDataLoader:
data_loader = MSMARCOEvalDataLoader(
eval_name=self.eval_args.eval_name,
dataset_dir=self.eval_args.dataset_dir,
cache_dir=self.eval_args.cache_path,
token=self.eval_args.token
)
return data_loader

View File

@@ -2,7 +2,9 @@ if [ -z "$HF_HUB_CACHE" ]; then
export HF_HUB_CACHE="$HOME/.cache/huggingface/hub"
fi
dataset_names="bn hi sw te th yo"
# dataset_names="bn hi sw te th yo"
dataset_names="sw"
HF_HUB_CACHE="/share/shared_models"
eval_args="\
--eval_name miracl \
@@ -23,7 +25,7 @@ eval_args="\
model_args="\
--embedder_name_or_path BAAI/bge-m3 \
--reranker_name_or_path BAAI/bge-reranker-v2-m3 \
--devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7 \
--devices cuda:0 \
--cache_dir $HF_HUB_CACHE \
--reranker_max_length 1024 \
"

View File

@@ -0,0 +1,39 @@
if [ -z "$HF_HUB_CACHE" ]; then
export HF_HUB_CACHE="$HOME/.cache/huggingface/hub"
fi
HF_HUB_CACHE="/share/shared_models"
dataset_names="passage"
eval_args="\
--eval_name msmarco \
--dataset_dir /share/chaofan/code/FlagEmbedding_update/data/msmarco \
--dataset_names $dataset_names \
--splits dl19 \
--corpus_embd_save_dir /share/chaofan/code/FlagEmbedding_update/data/msmarco/corpus_embd \
--output_dir /share/chaofan/code/FlagEmbedding_update/data/msmarco/search_results \
--search_top_k 1000 --rerank_top_k 100 \
--cache_path $HF_HUB_CACHE \
--overwrite False \
--k_values 10 100 \
--eval_output_method markdown \
--eval_output_path /share/chaofan/code/FlagEmbedding_update/data/msmarco/msmarco_eval_results.md \
--eval_metrics ndcg_at_10 recall_at_100 \
"
model_args="\
--embedder_name_or_path BAAI/bge-m3 \
--reranker_name_or_path BAAI/bge-reranker-v2-m3 \
--devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7 \
--cache_dir $HF_HUB_CACHE \
--reranker_max_length 1024 \
"
cmd="python -m FlagEmbedding.evaluation.msmarco \
$eval_args \
$model_args \
"
echo $cmd
eval $cmd

View File

@@ -23,6 +23,7 @@ setup(
'beir',
'deepspeed',
'flash-attn==2.5.6',
'mteb==1.15.0'
'mteb==1.15.0',
'ir-datasets'
],
)