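"""Evaluate a retriever (dense or BM25) on a query set against a jsonl corpus.

The script prepares the evaluation queries and corpus, builds or loads an
index, runs the search, and computes retrieval metrics on the main process;
raw search results can be cached to disk and reloaded on later runs.
"""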
import torch
import logging
import datasets
from accelerate import Accelerator
from transformers import HfArgumentParser
from dataclasses import dataclass, field, asdict
from src.retrieval import (
    RetrievalArgs,
    Retriever,
    RetrievalDataset,
    RetrievalMetric,
    TASK_CONFIG,
)
from src.utils.util import makedirs, FileLogger

logger = logging.getLogger(__name__)


@dataclass
class Args(RetrievalArgs):
    eval_data: str = field(
        default=None,
        metadata={'help': 'Path to the query jsonl file.'}
    )
    output_dir: str = field(
        default="data/outputs/",
        metadata={'help': 'Directory for saving encodings, indexes, and results.'}
    )
    corpus: str = field(
        default=None,
        metadata={'help': 'Path to the corpus jsonl file for retrieval.'}
    )
    key_template: str = field(
        default="{title} {text}",
        metadata={'help': 'Template for concatenating corpus columns into one retrieval key.'}
    )
    log_path: str = field(
        default="data/results/performance.log",
        metadata={'help': 'Path to the performance log file.'}
    )
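
# Example invocation (a sketch: the script name and paths are hypothetical, and
# flags such as --retrieval_method and --hits come from fields on RetrievalArgs):
#   python eval_retrieval.py \
#       --eval_data data/queries.jsonl \
#       --corpus data/corpus.jsonl \
#       --retrieval_method dense \
#       --hits 10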
def main(args, accelerator=None, log=True):
    if accelerator is None:
        accelerator = Accelerator(cpu=args.cpu)

    with accelerator.main_process_first():
        config = TASK_CONFIG[args.version]
        instruction = config["instruction"]

        # the evaluation task must be known before the instruction can be selected
        # NOTE: only dense retrieval uses an instruction
        if args.eval_data is not None and args.add_instruction and args.retrieval_method == "dense":
            raw_eval_dataset = datasets.load_dataset('json', data_files=args.eval_data, split='train', cache_dir=args.dataset_cache_dir)
            eval_task = raw_eval_dataset[0]["task"]
        else:
            eval_task = None
        eval_dataset = RetrievalDataset.prepare_eval_dataset(
            data_file=args.eval_data,
            cache_dir=args.dataset_cache_dir,
            instruction=instruction[eval_task] if eval_task is not None else None,
        )
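        # key_template stitches corpus columns into a single retrieval key: the
        # default "{title} {text}" joins each document's title and text fields
        # (assuming prepare_corpus applies str.format-style substitution)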
        corpus = RetrievalDataset.prepare_corpus(
            data_file=args.corpus,
            key_template=args.key_template,
            cache_dir=args.dataset_cache_dir,
            instruction=instruction[eval_task] if eval_task is not None else None,
        )

    result_path = RetrievalMetric._get_save_path(args.eval_data, args.output_dir, field="result", save_name=args.save_name)

    if args.load_result:
        query_ids, preds = RetrievalMetric._load_result(result_path)
    else:
        retriever = Retriever(
            retrieval_method=args.retrieval_method,
            # for dense retriever
            query_encoder=args.query_encoder,
            key_encoder=args.key_encoder,
            pooling_method=args.pooling_method,
            dense_metric=args.dense_metric,
            query_max_length=args.query_max_length,
            key_max_length=args.key_max_length,
            tie_encoders=args.tie_encoders,
            truncation_side=args.truncation_side,
            cache_dir=args.model_cache_dir,
            dtype=args.dtype,
            accelerator=accelerator,
            # for bm25 retriever
            anserini_dir=args.anserini_dir,
            k1=args.k1,
            b=args.b,
        )
        retriever.index(
            corpus,
            output_dir=args.output_dir,
            # for dense retriever
            embedding_name=args.embedding_name,
            index_factory=args.faiss_index_factory,
            load_encode=args.load_encode,
            save_encode=args.save_encode,
            load_index=args.load_index,
            save_index=args.save_index,
            batch_size=args.batch_size,
            # for bm25 retriever
            threads=args.threads,
            language=args.language,
            storeDocvectors=args.storeDocvectors,
            load_collection=args.load_collection,
        )
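        # search the index with the evaluation queries; preds presumably holds
        # the ranked candidates per query, aligned one row per entry of query_ids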
        query_ids, preds = retriever.search(
            eval_dataset=eval_dataset,
            hits=args.hits,
            # for dense retriever
            batch_size=args.batch_size,
        )

        del retriever
        torch.cuda.empty_cache()
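        # persist the raw search results so later runs can skip retrieval
        # entirely via --load_result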
        if args.save_result and accelerator.process_index == 0:
            RetrievalMetric._save_result(query_ids, preds, result_path)

    if accelerator.process_index == 0:
        # NOTE: this corpus is for computing metrics, where no instruction is given
        no_instruction_corpus = RetrievalDataset.prepare_corpus(
            data_file=args.corpus,
            key_template=args.key_template,
            cache_dir=args.dataset_cache_dir,
        )
        metrics = RetrievalMetric.get_metric_fn(
            args.metrics,
            cutoffs=args.cutoffs,
            eval_data=args.eval_data,
            corpus=no_instruction_corpus,
            save_name=args.save_name,
            output_dir=args.output_dir,
            save_to_output=args.save_to_output,
            max_neg_num=args.max_neg_num,
            cache_dir=args.dataset_cache_dir,
            filter_answers=args.filter_answers,
        )(query_ids, preds)

        if log:
            file_logger = FileLogger(makedirs(args.log_path))
            file_logger.log(metrics, Args=asdict(args))
    else:
        metrics = {}

    accelerator.wait_for_everyone()
    return query_ids, preds, metrics
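
# main() can also be driven programmatically, e.g. to reuse one Accelerator
# across several evaluations (a sketch; the field values are hypothetical):
#   accelerator = Accelerator()
#   args = Args(eval_data="data/queries.jsonl", corpus="data/corpus.jsonl")
#   query_ids, preds, metrics = main(args, accelerator=accelerator)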

if __name__ == "__main__":
    parser = HfArgumentParser([Args])
    args, = parser.parse_args_into_dataclasses()
    main(args)