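"""
Score candidate keys with a language model.

For each query in ``eval_data``, every candidate key (``pos`` + ``neg``, or the
``history`` turns for chat data) is combined with the query and one of its
answers, and the LM's negative log-likelihood of the answer is computed.
The negated NLLs are then written back into the data as a ``teacher_scores``
field by ``collate_scores``.
"""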
import os
import json
import logging
import random
import datasets
from tqdm import tqdm
from datetime import timedelta
from accelerate import Accelerator, InitProcessGroupKwargs
from torch.utils.data import DataLoader
from dataclasses import dataclass, field
from collections import defaultdict
from transformers import HfArgumentParser
from src.lm import LM, LMArgs
from src.utils.util import split_file_dir_name_ext, makedirs, save_pickle, load_pickle, remove_eos, DefaultDataCollator, DatasetProcessFn
logger = logging.getLogger(__name__)
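

# The eval_data jsonl is expected (judging from `process_lm_scoring` below) to
# contain lines roughly of the following shape; the field values here are
# purely illustrative:
#   {"query_id": 0, "task": "qa", "query": "Who wrote Hamlet?",
#    "answers": ["William Shakespeare"], "pos": ["Hamlet is a tragedy ..."],
#    "neg": ["The Globe Theatre ..."]}
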
@dataclass
class ScoreArgs(LMArgs):
    eval_data: str = field(
        default=None,
        metadata={'help': 'Query jsonl.'}
    )
    context_max_length: int = field(
        default=1024,
        metadata={'help': 'Max length for lm.'}
    )
    key_max_length: int = field(
        default=512,
        metadata={'help': 'Max length for key.'}
    )
    lm_batch_size: int = field(
        default=4,
        metadata={'help': 'Batch size for LM scoring.'},
    )
    save_name: str = field(
        default="llama2-7b-chat",
        metadata={'help': 'Name of the scored file.'}
    )
    load_score: bool = field(
        default=False,
        metadata={'help': 'Load score from temporary file?'}
    )
def process_lm_scoring(tokenizer, key_max_length=512):
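    # Probe the tokenizer with a dummy string to see whether it automatically adds
    # BOS/EOS tokens: a 1 at the first/last position of `special_tokens_mask`
    # indicates an automatically added BOS/EOS respectively.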
    test = tokenizer("test", return_special_tokens_mask=True)["special_tokens_mask"]
    has_bos = has_eos = False
    if test[0] == 1:
        has_bos = True
    if test[-1] == 1:
        has_eos = True

    @DatasetProcessFn(augment=True)
    def _process(query, answers, query_id, task, pos=None, neg=None, history=None, context_inputs=None, query_inputs=None, answer_inputs=None, score_inputs=None, _index=None, **kwds):
        """Yield one scoring example per key (pos & neg, or history turns)."""
        if task in ["qa", "convsearch"]:
            template = "Knowledge: {key.strip()}\n\nQuestion: {query.strip()}\n\nAnswer: {answer.strip()}"
elif task == "icl":
template = "{key}\n{query}\n{answer}"
elif task == "lrlm":
# template = "{key}{continuation[i]}{context}{query}{answer}"
pass
elif task == "chat":
template = "{key}\nSpeaker 1: {query}\nSpeaker 2: {answer}"
else:
raise NotImplementedError(f"Task type {task} not implemented!")
output = defaultdict(list)
# NOTE: sample 1 answer for scoring if there are multiple
if len(answers) > 1:
answer = random.choice(answers)
else:
answer = answers[0]
if history is not None:
assert task == "chat", f"Found history={history} is not None but task={task} is not 'chat'!"
keys = history
else:
keys = pos + neg
for i, key in enumerate(keys):
# NOTE: do not add special tokens!
if task == "lrlm":
score_input = score_inputs[i]
input_ids = score_input + context_inputs + query_inputs + answer_inputs
attention_mask = [1 for _ in input_ids]
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask
}
labels = input_ids.copy()
answer_length = len(answer_inputs)
labels[:-answer_length] = [-100] * (len(labels) - answer_length)
inputs["labels"] = labels
else:
# truncate key
key = tokenizer.decode(tokenizer.encode(key, add_special_tokens=False, max_length=key_max_length, truncation=True))
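                # NOTE: render the template by evaluating it as an f-string, so the
                # `{key}` / `{query}` / `{answer}` placeholders are filled from the
                # local variables of this function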
                seq = eval(f"f{repr(template)}")
                inputs = tokenizer(seq, return_token_type_ids=False)
                if has_eos:
                    inputs = remove_eos(inputs, tokenizer.eos_token_id)
                # find answer length
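                # (tokenize "Answer: " + answer and subtract the token count of
                # "Answer:"; the difference approximates the number of tokens
                # contributed by the answer itself)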
                answer_seq = tokenizer.encode("Answer: " + answer.lstrip(" "), add_special_tokens=False)
                answer_length = len(answer_seq) - len(tokenizer.encode("Answer:", add_special_tokens=False))
                assert answer_length > 0, f"No answer found in inputs {_index}!"
                # take care of padded tokens
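                # (-100 is the label ignore index of the cross-entropy loss, so padded
                # and non-answer positions do not contribute to the NLL)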
                labels = inputs["input_ids"].copy()
                labels = [x if inputs["attention_mask"][i] == 1 else -100 for i, x in enumerate(labels)]
                labels[:-answer_length] = [-100] * (len(labels) - answer_length)
                inputs["labels"] = labels

            for k, v in inputs.items():
                output[k].append(v)
            output["query_id"].append(query_id)
        return output
    return _process
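

# Each processed example therefore carries `input_ids`, `attention_mask`,
# `labels` (-100 everywhere except the answer tokens) and `query_id`, with one
# example emitted per (query, key) pair in the order the keys appear.
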
def collate_scores(eval_data, save_name):
    """
    Collate the LM scorings based on query_ids.
    Append a 'teacher_scores' column to the eval_data and save the result at
    {eval_data_name}.scored.{save_name}{ext} next to the original file.
    """
    def collate(query_ids, scores):
        # only on main process
        eval_data_folder, eval_data_name, eval_data_ext = split_file_dir_name_ext(eval_data)
        data_save_path = os.path.join(eval_data_folder, f"{eval_data_name}.scored.{save_name}" + eval_data_ext)
        makedirs(data_save_path)

        prev_query_id = None
        teacher_scores = []
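
        # `query_ids`/`scores` are assumed to be aligned with the flattened
        # (query, key) examples produced by `process_lm_scoring`, in file order
        # (the asserts below check this); whenever `query_id` changes, all scores
        # of the previous query have been collected and its line can be read from
        # `eval_data` and written out with the accumulated scores.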
        try:
            logger.info(f"saving data to {data_save_path}...")
            with open(eval_data) as f, open(data_save_path, "w") as g:
                for query_id, score in tqdm(zip(query_ids, scores)):
                    if (query_id != prev_query_id) and (prev_query_id is not None):
                        sample = json.loads(f.readline().strip())
                        assert prev_query_id == sample["query_id"], f"Found incompatible query_id from data ({sample['query_id']}) and from eval_preds ({prev_query_id})"
                        if "history" in sample:
                            assert len(sample["history"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['history'])}) and from eval_preds ({len(teacher_scores)})"
                        else:
                            assert len(sample["pos"] + sample["neg"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['pos'] + sample['neg'])}) and from eval_preds ({len(teacher_scores)})"
                        sample["teacher_scores"] = teacher_scores.copy()
                        if sample["task"] == "lrlm" and "query_inputs" in sample:
                            del sample["query_inputs"]
                            del sample["answer_inputs"]
                            del sample["context_inputs"]
                            del sample["score_inputs"]
                        g.write(json.dumps(sample, ensure_ascii=False) + "\n")
                        teacher_scores.clear()

                    # accumulate scores of different keys for the same query
                    # log likelihood
                    teacher_scores.append(-score)
                    prev_query_id = query_id

                # NOTE: the last line
                sample = json.loads(f.readline().strip())
                assert prev_query_id == sample["query_id"], f"Found incompatible query_id from data ({sample['query_id']}) and from eval_preds ({prev_query_id})"
                if "history" in sample:
                    assert len(sample["history"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['history'])}) and from eval_preds ({len(teacher_scores)})"
                else:
                    assert len(sample["pos"] + sample["neg"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['pos'] + sample['neg'])}) and from eval_preds ({len(teacher_scores)})"
                sample["teacher_scores"] = teacher_scores.copy()
                if sample["task"] == "lrlm" and "query_inputs" in sample:
                    del sample["query_inputs"]
                    del sample["answer_inputs"]
                    del sample["context_inputs"]
                    del sample["score_inputs"]
                g.write(json.dumps(sample, ensure_ascii=False) + "\n")
                teacher_scores.clear()
        except:
            save_path = os.path.join(eval_data_folder, f"{eval_data_name}.{save_name}.pkl")
            logger.error(f"Error when trying to save to json file. Save scores to {save_path} instead!")
            save_pickle((query_ids, scores), save_path)
            raise
    return collate


def main():
    parser = HfArgumentParser([ScoreArgs])
    args, = parser.parse_args_into_dataclasses()
    args: ScoreArgs

    accelerator = Accelerator(cpu=args.cpu, kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=100000))])

logger.info(f"Loading data from {args.eval_data}...")
    llm = LM(
        model_name_or_path=args.model_name_or_path,
        dtype=args.lm_dtype,
        padding_side=args.padding_side,
        cache_dir=args.model_cache_dir,
        accelerator=accelerator
    )
    llm.to(accelerator.device)
    tokenizer = llm.tokenizer

logging.info(f"Loading data from {args.eval_data}...")
    if args.load_score:
        eval_data_folder, eval_data_name, eval_data_ext = split_file_dir_name_ext(args.eval_data)
        save_path = os.path.join(eval_data_folder, f"{eval_data_name}.{args.save_name}.pkl")
        # the pickle was saved as a (query_ids, scores) tuple, so unpack it here
        query_ids, scores = load_pickle(save_path)
    else:
        with accelerator.main_process_first():
            # dataset = datasets.load_dataset("json", data_files=args.eval_data, split="train[:100]", cache_dir=args.dataset_cache_dir)
            dataset = datasets.load_dataset("json", data_files=args.eval_data, split="train", cache_dir=args.dataset_cache_dir)
            dataset = dataset.map(
                process_lm_scoring(tokenizer=tokenizer, key_max_length=args.key_max_length),
                remove_columns=dataset.column_names,
                batched=True,
                num_proc=32,
                with_indices=True
            )

        data_collator = DefaultDataCollator(tokenizer=tokenizer, add_position_ids=args.add_position_ids)
        dataloader = DataLoader(
            dataset,
            batch_size=args.lm_batch_size,
            collate_fn=data_collator,
            pin_memory=True,
        )
        dataloader = accelerator.prepare(dataloader)
        query_ids, scores = llm.compute_nlls(dataloader)
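
    # `llm.compute_nlls` is assumed (from its usage here and in `collate_scores`)
    # to return the query_id and answer NLL of every example, gathered across all
    # processes in dataset order; only rank 0 then collates and writes the scored file.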
    if accelerator.process_index == 0:
        collate_scores(args.eval_data, args.save_name)(query_ids, scores)


if __name__ == "__main__":
    main()