# Snapshot metadata (from web view): 2024-10-27 20:28:57 +08:00
#
# 191 lines
# 6.2 KiB
# Python

import os
import logging
import datasets
from copy import deepcopy
from accelerate import Accelerator
from torch.utils.data import DataLoader
from transformers import HfArgumentParser
from dataclasses import dataclass, field, asdict
from src.lm import SRLMArgs, SelfRetrievalLM
from src.retrieval import Retriever, RetrievalArgs, TASK_CONFIG
from src.utils.util import makedirs, remove_eos, DefaultDataCollator, DatasetProcessFn, FileLogger
import transformers

# Logger for this script; main() reports warnings through it.
logger = logging.getLogger(__name__)

# Silence transformers logging below ERROR level (e.g. the too-long-input warning).
transformers.logging.set_verbosity_error()
# merge two args to get unified arguments
@dataclass
class LRLMArgs(RetrievalArgs, SRLMArgs):
eval_data: str = field(
default="llm-embedder:lrlm/books3/test.json",
metadata={'help': 'Evaluation json file.'},
)
lm_batch_size: int = field(
default=1,
metadata={'help': 'Evaluation json file.'},
)
context_max_length: int = field(
default=32768,
metadata={'help': 'Evaluation json file.'},
)
anchor_length: int = field(
default=160000,
metadata={'help': 'Evaluation file containing long texts.'}
)
chunk_size: int = field(
default=128,
metadata={'help': 'How many tokens in a chunk?'}
)
key_num: int = field(
default=8,
metadata={'help': 'How many chunks to retrieve at a time?'}
)
chunk_batch_size: int = field(
default=1,
metadata={'help': 'How many retrieval & generation to execute in parallel?'}
)
log_path: str = field(
default="data/results/lrlm",
metadata={'help': 'Path to the file for logging.'}
)
debug_retrieval: bool = field(
default=False,
metadata={'help': 'Check retrieval queries and values?'}
)
def __post_init__(self):
super().__post_init__()
if self.retrieval_method == "bm25":
# NOTE: we can only use naive bm25 for self retrieval
self.retrieval_method = "naive-bm25"
def process_lrlm(tokenizer, context_max_length=4096, target_length=1024, anchor_length=160000):
    """Build the per-example processing function for LM perplexity evaluation.

    Each raw text is cut to its first ``anchor_length`` elements (characters,
    assuming ``text`` is a str), tokenized without special tokens, and
    left-truncated to at most ``context_max_length`` tokens. Only the last
    ``target_length`` tokens keep their ids as labels; every earlier position
    is labeled -100. Texts that tokenize to fewer than ``target_length``
    tokens yield ``None`` (presumably dropped by ``DatasetProcessFn`` --
    confirm).

    Args:
        tokenizer: Hugging Face tokenizer. It is deep-copied, so the caller's
            tokenizer (and its ``truncation_side``) is left unchanged.
        context_max_length: Maximum number of tokens kept per example.
        target_length: Number of trailing tokens that receive labels.
        anchor_length: Number of leading text elements kept before tokenizing.

    Returns:
        The ``_process`` function wrapped by ``DatasetProcessFn``, intended
        for ``datasets.Dataset.map``.
    """
    # Truncate from the left so the end of the text (the scored target) survives.
    left_truncation_tokenizer = deepcopy(tokenizer)
    left_truncation_tokenizer.truncation_side = "left"

    @DatasetProcessFn()
    def _process(text, **kwds):
        # NOTE(review): assumes DatasetProcessFn calls this once per example
        # with a single `text` value -- confirm against its implementation.
        text = text[:anchor_length]
        inputs = left_truncation_tokenizer(
            text,
            max_length=context_max_length,
            truncation=True,
            return_token_type_ids=False,
            add_special_tokens=False,
        )
        input_ids = inputs["input_ids"]
        inputs_length = len(input_ids)
        if inputs_length < target_length:
            return None

        # Mask everything except the last `target_length` positions. The list is
        # built explicitly instead of slice-assigning `labels[:-target_length]`:
        # for target_length == 0, `[:-0]` is the empty slice, so that assignment
        # would insert extra -100s and make labels longer than input_ids.
        context_length = inputs_length - target_length
        labels = [-100] * context_length + input_ids[context_length:]
        inputs["labels"] = labels

        output = {}
        for k, v in inputs.items():
            output[k] = v
        return output

    return _process
def main():
    """Evaluate long-range LM perplexity with self-retrieval and log the result.

    Parses ``LRLMArgs`` from the command line, builds the ``Retriever`` and the
    ``SelfRetrievalLM``, tokenizes ``args.eval_data`` with ``process_lrlm``,
    computes perplexity over the (accelerate-prepared) dataloader, and on
    process 0 writes the metrics plus all arguments through ``FileLogger`` to
    ``<args.log_path>/<dataset_name>.log``.
    """
    parser = HfArgumentParser([LRLMArgs])
    args, = parser.parse_args_into_dataclasses()

    accelerator = Accelerator(cpu=args.cpu)

    retriever = Retriever(
        retrieval_method=args.retrieval_method,
        # for dense retriever
        query_encoder=args.query_encoder,
        key_encoder=args.key_encoder,
        pooling_method=args.pooling_method,
        dense_metric=args.dense_metric,
        query_max_length=args.query_max_length,
        key_max_length=args.key_max_length,
        tie_encoders=args.tie_encoders,
        truncation_side=args.truncation_side,
        cache_dir=args.model_cache_dir,
        dtype=args.dtype,
        accelerator=accelerator,
        # for bm25 retriever
        anserini_dir=args.anserini_dir,
        k1=args.k1,
        b=args.b
    )

    # Optional "lrlm" task instruction from the selected TASK_CONFIG version.
    if args.add_instruction:
        instruction = TASK_CONFIG[args.version]["instruction"]["lrlm"]
    else:
        instruction = None

    srlm = SelfRetrievalLM(
        model_name_or_path=args.model_name_or_path,
        retriever=retriever,
        dtype=args.lm_dtype,
        device_map=args.lm_device_map,
        padding_side=args.padding_side,
        cache_dir=args.model_cache_dir,
        context_window_size=args.context_window_size,
        chunk_size=args.chunk_size,
        key_num=args.key_num,
        chunk_batch_size=args.chunk_batch_size,
        add_key_continuation=args.add_key_continuation,
        retrieval_method=args.retrieval_method,
        order_method=args.order_method,
        integrate_method=args.integrate_method,
        instruction=instruction,
        debug_retrieval=args.debug_retrieval,
        add_sep=args.add_sep,
        accelerator=accelerator,
    )
    tokenizer = srlm.tokenizer

    # Use the module logger (as the warning below does) rather than the root
    # logger; root-level logging.info() may implicitly call basicConfig().
    logger.info(f"Loading data from {args.eval_data}...")

    # With retrieval disabled, force the tokenized context length to equal the
    # LM context window.
    if args.retrieval_method == "no" and args.context_max_length != args.context_window_size:
        logger.warning(f"Found retrieval_method is 'no', setting context_max_length to the same as context_window_size ({args.context_window_size})!")
        args.context_max_length = args.context_window_size

    # The main process loads and maps the data first; the others follow after
    # it exits the block (presumably so they reuse the datasets cache).
    with accelerator.main_process_first():
        dataset = datasets.load_dataset("json", data_files=args.eval_data, split="train", cache_dir=args.dataset_cache_dir)
        dataset = dataset.map(process_lrlm(
            tokenizer,
            context_max_length=args.context_max_length,
            target_length=args.target_length,
            anchor_length=args.anchor_length,
        ), remove_columns=dataset.column_names, batched=True, batch_size=50, num_proc=64)

    data_collator = DefaultDataCollator(tokenizer=tokenizer, add_position_ids=args.add_position_ids)
    dataloader = DataLoader(
        dataset,
        batch_size=args.lm_batch_size,
        collate_fn=data_collator,
        pin_memory=True,
    )
    dataloader = accelerator.prepare(dataloader)

    perplexity = srlm.compute_perplexity(dataloader)
    metrics = {"perplexity": perplexity}

    if accelerator.process_index == 0:
        # Name the log after the directory containing the eval file
        # (e.g. "books3" for ".../books3/test.json"). If the path has no
        # directory component, fall back to the file stem, where indexing the
        # split path at [-2] would raise IndexError.
        eval_path = os.path.normpath(args.eval_data)
        dataset_name = (
            os.path.basename(os.path.dirname(eval_path))
            or os.path.splitext(os.path.basename(eval_path))[0]
        )
        log_path = os.path.join(args.log_path, f"{dataset_name}.log")
        file_logger = FileLogger(makedirs(log_path))
        file_logger.log(metrics, Args=asdict(args))
# Script entry point: run the evaluation only when executed directly.
if __name__ == "__main__":
    main()