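"""Evaluate a retriever (dense or BM25) on a query file against a corpus.

The script prepares the evaluation queries and the corpus (optionally
prepending a task-specific instruction for dense retrieval), builds or loads
an index, searches the top hits for every query, and computes the configured
retrieval metrics on the main process.
"""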
import os
import torch
import logging
import datasets
from typing import Optional
from accelerate import Accelerator
from transformers import HfArgumentParser
from dataclasses import dataclass, field, asdict

from src.retrieval import (
    RetrievalArgs,
    Retriever,
    RetrievalDataset,
    RetrievalMetric,
    TASK_CONFIG,
)
from src.utils.util import makedirs, FileLogger

logger = logging.getLogger(__name__)

@dataclass
class Args(RetrievalArgs):
    eval_data: Optional[str] = field(
        default=None,
        metadata={'help': 'Path to the query jsonl file.'}
    )
    output_dir: str = field(
        default="data/outputs/",
        metadata={'help': 'Directory to save indexes, embeddings, and results.'}
    )
    corpus: Optional[str] = field(
        default=None,
        metadata={'help': 'Path to the corpus for retrieval.'}
    )
    key_template: str = field(
        default="{title} {text}",
        metadata={'help': 'Template for concatenating corpus columns into one retrieval key.'}
    )
    log_path: str = field(
        default="data/results/performance.log",
        metadata={'help': 'Path to the performance log file.'}
    )

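# NOTE: all remaining options (retrieval_method, encoder paths, metrics,
# cutoffs, etc.) are inherited from RetrievalArgs.
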
def main(args, accelerator=None, log=True):
    if accelerator is None:
        accelerator = Accelerator(cpu=args.cpu)

    with accelerator.main_process_first():
        config = TASK_CONFIG[args.version]
        instruction = config["instruction"]

        # the evaluation task must be known before the instruction can be resolved
        # NOTE: only dense retrieval uses instructions
        if args.eval_data is not None and args.add_instruction and args.retrieval_method == "dense":
            raw_eval_dataset = datasets.load_dataset('json', data_files=args.eval_data, split='train', cache_dir=args.dataset_cache_dir)
            eval_task = raw_eval_dataset[0]["task"]
        else:
            eval_task = None

        eval_dataset = RetrievalDataset.prepare_eval_dataset(
            data_file=args.eval_data,
            cache_dir=args.dataset_cache_dir,
            instruction=instruction[eval_task] if eval_task is not None else None,
        )
        corpus = RetrievalDataset.prepare_corpus(
            data_file=args.corpus,
            key_template=args.key_template,
            cache_dir=args.dataset_cache_dir,
            instruction=instruction[eval_task] if eval_task is not None else None,
        )

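    # retrieval results can be cached on disk (--save_result) and reloaded
    # (--load_result), so metrics can be recomputed without re-running retrieval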
    result_path = RetrievalMetric._get_save_path(args.eval_data, args.output_dir, field="result", save_name=args.save_name)

    if args.load_result:
        query_ids, preds = RetrievalMetric._load_result(result_path)
    else:
        retriever = Retriever(
            retrieval_method=args.retrieval_method,
            # for the dense retriever
            query_encoder=args.query_encoder,
            key_encoder=args.key_encoder,
            pooling_method=args.pooling_method,
            dense_metric=args.dense_metric,
            query_max_length=args.query_max_length,
            key_max_length=args.key_max_length,
            tie_encoders=args.tie_encoders,
            truncation_side=args.truncation_side,
            cache_dir=args.model_cache_dir,
            dtype=args.dtype,
            accelerator=accelerator,
            # for the bm25 retriever
            anserini_dir=args.anserini_dir,
            k1=args.k1,
            b=args.b,
        )

        retriever.index(
            corpus,
            output_dir=args.output_dir,
            # for the dense retriever
            embedding_name=args.embedding_name,
            index_factory=args.faiss_index_factory,
            load_encode=args.load_encode,
            save_encode=args.save_encode,
            load_index=args.load_index,
            save_index=args.save_index,
            batch_size=args.batch_size,
            # for the bm25 retriever
            threads=args.threads,
            language=args.language,
            storeDocvectors=args.storeDocvectors,
            load_collection=args.load_collection,
        )

        query_ids, preds = retriever.search(
            eval_dataset=eval_dataset,
            hits=args.hits,
            # for the dense retriever
            batch_size=args.batch_size,
        )

        # free GPU memory before computing metrics
        del retriever
        torch.cuda.empty_cache()

        if args.save_result and accelerator.process_index == 0:
            RetrievalMetric._save_result(query_ids, preds, result_path)

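    # metrics are computed on the main process only; other ranks return {}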
    if accelerator.process_index == 0:
        # NOTE: this corpus is only used for computing metrics, so no instruction is prepended
        no_instruction_corpus = RetrievalDataset.prepare_corpus(
            data_file=args.corpus,
            key_template=args.key_template,
            cache_dir=args.dataset_cache_dir,
        )

        metrics = RetrievalMetric.get_metric_fn(
            args.metrics,
            cutoffs=args.cutoffs,
            eval_data=args.eval_data,
            corpus=no_instruction_corpus,
            save_name=args.save_name,
            output_dir=args.output_dir,
            save_to_output=args.save_to_output,
            max_neg_num=args.max_neg_num,
            cache_dir=args.dataset_cache_dir,
            filter_answers=args.filter_answers,
        )(query_ids, preds)

        if log:
            file_logger = FileLogger(makedirs(args.log_path))
            file_logger.log(metrics, Args=asdict(args))
    else:
        metrics = {}

    accelerator.wait_for_everyone()
    return query_ids, preds, metrics


if __name__ == "__main__":
    parser = HfArgumentParser([Args])
    args, = parser.parse_args_into_dataclasses()
    main(args)
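
# Example invocation (illustrative only; the script name and data paths are
# placeholders, and most flags come from RetrievalArgs):
#
#   accelerate launch eval_retrieval.py \
#       --eval_data data/queries.jsonl \
#       --corpus data/corpus.jsonl \
#       --retrieval_method dense \
#       --hits 100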