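"""Evaluate the perplexity of a self-retrieval language model on long
conversation histories (by default the MSC test split from LLM-Embedder).

A hypothetical invocation, assuming this file is saved as eval_msc.py at the
repository root (the flags below are defined in LRLMArgs; further flags come
from its parents RetrievalArgs / SRLMArgs):

    python eval_msc.py \
        --eval_data llm-embedder:chat/msc/test.json \
        --lm_batch_size 4 \
        --log_path data/results/msc/msc.log
"""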
import logging
import datasets
import torch
import numpy as np
import transformers
from accelerate import Accelerator
from torch.utils.data import DataLoader
from transformers import HfArgumentParser
from dataclasses import dataclass, field, asdict

from src.lm import SRLMArgs, SelfRetrievalLM
from src.retrieval import Retriever, RetrievalArgs, TASK_CONFIG
from src.utils.util import makedirs, pad_nested_lists, get_max_length_in_nested_lists, FileLogger

logger = logging.getLogger(__name__)

# disable the too-long-input warning from transformers
transformers.logging.set_verbosity_error()


# merge the retrieval arguments and the LM arguments into one unified set
@dataclass
class LRLMArgs(RetrievalArgs, SRLMArgs):
    eval_data: str = field(
        default="llm-embedder:chat/msc/test.json",
        metadata={'help': 'Evaluation file containing long texts.'}
    )
    lm_batch_size: int = field(
        default=1,
        metadata={'help': 'Evaluation batch size.'},
    )
    add_position_ids: bool = field(
        default=False,
        metadata={'help': 'Create position ids based on attention masks? Useful when training left-padded models with absolute position embeddings.'}
    )
    key_num: int = field(
        default=1,
        metadata={'help': 'How many chunks to retrieve at a time?'}
    )
    log_path: str = field(
        default="data/results/msc/msc.log",
        metadata={'help': 'Path to the file for logging.'}
    )
    debug_retrieval: bool = field(
        default=False,
        metadata={'help': 'Check retrieval queries and values?'}
    )
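

# Every field above, plus everything inherited from RetrievalArgs and SRLMArgs,
# is exposed as a command-line flag by HfArgumentParser in main(), e.g.
# `--key_num 2 --debug_retrieval` (for a bool field defaulting to False,
# passing the bare flag sets it to True).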


@dataclass
class HistoryCollator:
    """Collate histories, pad them, and return masks"""
    def __call__(self, batch_elem):
        first_elem = batch_elem[0]
        return_batch = {}

        for key in first_elem:
            batch_value = [elem[key] for elem in batch_elem]
            if key == "history":
                # pad every history in the batch to the longest one and keep a mask
                longest = get_max_length_in_nested_lists(batch_value)
                batch_value, history_mask = pad_nested_lists(batch_value, longest, "", "right")
                history_mask = torch.tensor(history_mask, dtype=torch.bool)
                return_batch["history_mask"] = history_mask

            elif key == "answers":
                # there is only one answer per sample
                key = "answer"
                batch_value = [elem[0] for elem in batch_value]

            elif key in ["query_id", "task"]:
                continue

            # strip whitespace here for convenience
            return_batch[key] = np.char.strip(np.array(batch_value))
        return return_batch
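

# A toy illustration of the assumed pad_nested_lists semantics, inferred from
# the call site above (the actual helper lives in src.utils.util and may
# differ in detail):
#
#   histories = [["hi", "hello"], ["hey"]]
#   padded, mask = pad_nested_lists(histories, 2, "", "right")
#   # padded -> [["hi", "hello"], ["hey", ""]]
#   # mask   -> [[1, 1], [1, 0]]
#
# After the torch.bool cast, history_mask marks real turns (True) versus
# right-padding (False).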


def main():
    parser = HfArgumentParser([LRLMArgs])
    args, = parser.parse_args_into_dataclasses()

    accelerator = Accelerator(cpu=args.cpu)

    retriever = Retriever(
        retrieval_method=args.retrieval_method,
        # for dense retriever
        query_encoder=args.query_encoder,
        key_encoder=args.key_encoder,
        pooling_method=args.pooling_method,
        dense_metric=args.dense_metric,
        query_max_length=args.query_max_length,
        key_max_length=args.key_max_length,
        tie_encoders=args.tie_encoders,
        truncation_side=args.truncation_side,
        cache_dir=args.model_cache_dir,
        dtype=args.dtype,
        accelerator=accelerator,
        # for bm25 retriever
        anserini_dir=args.anserini_dir,
        k1=args.k1,
        b=args.b,
    )
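    # Both the dense and the BM25 argument groups are always passed in;
    # presumably only the group matching retrieval_method is actually used
    # (see src/retrieval.py).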

    if args.add_instruction:
        instruction = TASK_CONFIG[args.version]["instruction"]["chat"]
    else:
        instruction = None

    lm = SelfRetrievalLM(
        model_name_or_path=args.model_name_or_path,
        retriever=retriever,
        dtype=args.lm_dtype,
        device_map=args.lm_device_map,
        padding_side=args.padding_side,
        cache_dir=args.model_cache_dir,
        context_window_size=args.context_window_size,
        chunk_size=args.chunk_size,
        key_num=args.key_num,
        chunk_batch_size=args.chunk_batch_size,
        retrieval_method=args.retrieval_method,
        order_method=args.order_method,
        integrate_method=args.integrate_method,
        instruction=instruction,
        debug_retrieval=args.debug_retrieval,
        add_sep=args.add_sep,
        accelerator=accelerator,
    )
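    # SelfRetrievalLM presumably splits inputs that exceed context_window_size
    # into chunk_size-sized chunks, retrieves key_num of them at a time with
    # the retriever, and re-integrates them according to order_method /
    # integrate_method (see src/lm.py).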

    logger.info(f"Loading data from {args.eval_data}...")

    # load the dataset on the main process first so the other processes reuse the cache
    with accelerator.main_process_first():
        dataset = datasets.load_dataset("json", data_files=args.eval_data, split="train", cache_dir=args.dataset_cache_dir)

    data_collator = HistoryCollator()
    dataloader = DataLoader(
        dataset,
        batch_size=args.lm_batch_size,
        collate_fn=data_collator,
        pin_memory=True,
    )
    dataloader = accelerator.prepare(dataloader)

    perplexity = lm.compute_perplexity(dataloader)
    metrics = {"perplexity": perplexity}

    # log only once, from the main process
    if accelerator.process_index == 0:
        file_logger = FileLogger(makedirs(args.log_path))
        file_logger.log(metrics, Args=asdict(args))


if __name__ == "__main__":
    main()