mirror of https://github.com/FlagOpen/FlagEmbedding.git (synced 2025-06-27 02:39:58 +00:00)
import os
import json
import logging
import random
import datasets
from tqdm import tqdm
from datetime import timedelta
from accelerate import Accelerator, InitProcessGroupKwargs
from torch.utils.data import DataLoader
from dataclasses import dataclass, field
from collections import defaultdict
from transformers import HfArgumentParser

from src.lm import LM, LMArgs
from src.utils.util import split_file_dir_name_ext, makedirs, save_pickle, load_pickle, remove_eos, DefaultDataCollator, DatasetProcessFn

logger = logging.getLogger(__name__)
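
# NOTE (added description, inferred from the code below): this script scores every
# candidate key of each query with a language model and writes the resulting
# "teacher_scores" back next to the original jsonl data (see collate_scores).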


@dataclass
class ScoreArgs(LMArgs):
    eval_data: str = field(
        default=None,
        metadata={'help': 'Path to the query jsonl file.'}
    )
    context_max_length: int = field(
        default=1024,
        metadata={'help': 'Max context length for the lm.'}
    )
    key_max_length: int = field(
        default=512,
        metadata={'help': 'Max length for each key.'}
    )
    lm_batch_size: int = field(
        default=4,
        metadata={'help': 'Batch size for lm scoring.'},
    )
    save_name: str = field(
        default="llama2-7b-chat",
        metadata={'help': 'Name of the scored file.'}
    )
    load_score: bool = field(
        default=False,
        metadata={'help': 'Load scores from the temporary pickle file?'}
    )
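
# NOTE (added): ScoreArgs inherits from LMArgs, so the model-related flags read in
# main() below (e.g. --model_name_or_path, --lm_dtype, --padding_side, --model_cache_dir,
# --dataset_cache_dir, --add_position_ids, --cpu) are defined there rather than here.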


def process_lm_scoring(tokenizer, key_max_length=512):
    # probe whether the tokenizer automatically adds bos/eos tokens
    test = tokenizer("test", return_special_tokens_mask=True)["special_tokens_mask"]
    has_bos = has_eos = False
    if test[0] == 1:
        has_bos = True
    if test[-1] == 1:
        has_eos = True

    @DatasetProcessFn(augment=True)
    def _process(query, answers, query_id, task, pos=None, neg=None, history=None, context_inputs=None, query_inputs=None, answer_inputs=None, score_inputs=None, _index=None, **kwds):
        """Yield one scoring instance for each key (pos & neg, or each history turn)."""
        if task in ["qa", "convsearch"]:
            template = "Knowledge: {key.strip()}\n\nQuestion: {query.strip()}\n\nAnswer: {answer.strip()}"
        elif task == "icl":
            template = "{key}\n{query}\n{answer}"
        elif task == "lrlm":
            # template = "{key}{continuation[i]}{context}{query}{answer}"
            pass
        elif task == "chat":
            template = "{key}\nSpeaker 1: {query}\nSpeaker 2: {answer}"
        else:
            raise NotImplementedError(f"Task type {task} not implemented!")

        output = defaultdict(list)
        # NOTE: sample 1 answer for scoring if there are multiple
        if len(answers) > 1:
            answer = random.choice(answers)
        else:
            answer = answers[0]

        if history is not None:
            assert task == "chat", f"Found history={history} is not None but task={task} is not 'chat'!"
            keys = history
        else:
            keys = pos + neg

        for i, key in enumerate(keys):
            # NOTE: do not add special tokens!
            if task == "lrlm":
                # lrlm inputs are pre-tokenized; only the answer tokens are scored
                score_input = score_inputs[i]
                input_ids = score_input + context_inputs + query_inputs + answer_inputs
                attention_mask = [1 for _ in input_ids]
                inputs = {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask
                }
                labels = input_ids.copy()
                answer_length = len(answer_inputs)
                labels[:-answer_length] = [-100] * (len(labels) - answer_length)
                inputs["labels"] = labels
            else:
                # truncate key
                key = tokenizer.decode(tokenizer.encode(key, add_special_tokens=False, max_length=key_max_length, truncation=True))

                # evaluate the template as an f-string against the local variables
                seq = eval(f"f{repr(template)}")
                inputs = tokenizer(seq, return_token_type_ids=False)
                if has_eos:
                    inputs = remove_eos(inputs, tokenizer.eos_token_id)

                # find answer length (tokenized length of the answer when it follows other text)
                answer_seq = tokenizer.encode("Answer: " + answer.lstrip(" "), add_special_tokens=False)
                answer_length = len(answer_seq) - len(tokenizer.encode("Answer:", add_special_tokens=False))
                assert answer_length > 0, f"No answer found in inputs {_index}!"

                # take care of padded tokens, then mask out everything except the answer
                labels = inputs["input_ids"].copy()
                labels = [x if inputs["attention_mask"][i] == 1 else -100 for i, x in enumerate(labels)]
                labels[:-answer_length] = [-100] * (len(labels) - answer_length)
                inputs["labels"] = labels

            for k, v in inputs.items():
                output[k].append(v)
            output["query_id"].append(query_id)
        return output
    return _process
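
# Added illustration (never called anywhere): a minimal sketch of the runtime f-string
# trick used in _process, where eval(f"f{repr(template)}") fills a plain template string
# from local variables. All values below are made up.
def _template_demo():
    key = "Paris is the capital of France."
    query = "What is the capital of France?"
    answer = "Paris"
    template = "Knowledge: {key.strip()}\n\nQuestion: {query.strip()}\n\nAnswer: {answer.strip()}"
    # repr(template) re-quotes the string, and the leading "f" turns it into an f-string literal
    seq = eval(f"f{repr(template)}")
    # -> "Knowledge: Paris is the capital of France.\n\nQuestion: What is the capital of France?\n\nAnswer: Paris"
    return seq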


def collate_scores(eval_data, save_name):
    """
    Collate the lm scores by query_id.
    Append a 'teacher_scores' column to each sample and save the result at
    {eval_data_name}.scored.{save_name}{eval_data_ext}.
    """
    def collate(query_ids, scores):
        # only runs on the main process
        eval_data_folder, eval_data_name, eval_data_ext = split_file_dir_name_ext(eval_data)
        data_save_path = os.path.join(eval_data_folder, f"{eval_data_name}.scored.{save_name}" + eval_data_ext)
        makedirs(data_save_path)

        def flush_sample(f, g, query_id, teacher_scores):
            """Read the next sample from eval_data, attach the collected scores, and write it out."""
            sample = json.loads(f.readline().strip())
            assert query_id == sample["query_id"], f"Found incompatible query_id from data ({sample['query_id']}) and from eval_preds ({query_id})"
            if "history" in sample:
                assert len(sample["history"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['history'])}) and from eval_preds ({len(teacher_scores)})"
            else:
                assert len(sample["pos"] + sample["neg"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['pos'] + sample['neg'])}) and from eval_preds ({len(teacher_scores)})"
            sample["teacher_scores"] = teacher_scores.copy()
            if sample["task"] == "lrlm" and "query_inputs" in sample:
                del sample["query_inputs"]
                del sample["answer_inputs"]
                del sample["context_inputs"]
                del sample["score_inputs"]
            g.write(json.dumps(sample, ensure_ascii=False) + "\n")
            teacher_scores.clear()

        prev_query_id = None
        teacher_scores = []
        try:
            logger.info(f"saving data to {data_save_path}...")
            with open(eval_data) as f, open(data_save_path, "w") as g:
                for query_id, score in tqdm(zip(query_ids, scores)):
                    if (query_id != prev_query_id) and (prev_query_id is not None):
                        flush_sample(f, g, prev_query_id, teacher_scores)

                    # accumulate log-likelihood scores of different keys for the same query
                    teacher_scores.append(-score)
                    prev_query_id = query_id

                # NOTE: flush the last query
                flush_sample(f, g, prev_query_id, teacher_scores)

        except:
            save_path = os.path.join(eval_data_folder, f"{eval_data_name}.{save_name}.pkl")
            logger.error(f"Error when trying to save to json file. Save scores to {save_path} instead!")
            save_pickle((query_ids, scores), save_path)
            raise
    return collate
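
# Example of a record written by collate_scores (field values are made up for
# illustration; only the added "teacher_scores" list is produced by the code above,
# the other fields are copied from the input jsonl):
#   {"query_id": 0, "query": "...", "answers": ["..."], "task": "qa",
#    "pos": ["..."], "neg": ["...", "..."],
#    "teacher_scores": [-0.87, -2.31, -3.05]}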


def main():
    parser = HfArgumentParser([ScoreArgs])
    args, = parser.parse_args_into_dataclasses()
    args: ScoreArgs

    accelerator = Accelerator(cpu=args.cpu, kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=100000))])

    logger.info(f"Loading model from {args.model_name_or_path}...")
    llm = LM(
        model_name_or_path=args.model_name_or_path,
        dtype=args.lm_dtype,
        padding_side=args.padding_side,
        cache_dir=args.model_cache_dir,
        accelerator=accelerator
    )
    llm.to(accelerator.device)

    tokenizer = llm.tokenizer

    logger.info(f"Loading data from {args.eval_data}...")

    if args.load_score:
        eval_data_folder, eval_data_name, eval_data_ext = split_file_dir_name_ext(args.eval_data)
        save_path = os.path.join(eval_data_folder, f"{eval_data_name}.{args.save_name}.pkl")
        # the pickle stores the (query_ids, scores) tuple saved on failure in collate_scores
        query_ids, scores = load_pickle(save_path)

    else:
        with accelerator.main_process_first():
            # dataset = datasets.load_dataset("json", data_files=args.eval_data, split="train[:100]", cache_dir=args.dataset_cache_dir)
            dataset = datasets.load_dataset("json", data_files=args.eval_data, split="train", cache_dir=args.dataset_cache_dir)
            dataset = dataset.map(
                process_lm_scoring(tokenizer=tokenizer, key_max_length=args.key_max_length),
                remove_columns=dataset.column_names,
                batched=True,
                num_proc=32,
                with_indices=True
            )

        data_collator = DefaultDataCollator(tokenizer=tokenizer, add_position_ids=args.add_position_ids)
        dataloader = DataLoader(
            dataset,
            batch_size=args.lm_batch_size,
            collate_fn=data_collator,
            pin_memory=True,
        )
        dataloader = accelerator.prepare(dataloader)

        query_ids, scores = llm.compute_nlls(dataloader)

    if accelerator.process_index == 0:
        collate_scores(args.eval_data, args.save_name)(query_ids, scores)


if __name__ == "__main__":
    main()
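
# Example invocation (a sketch only; the launcher, the script path, and the model id are
# placeholders, and the LMArgs-provided flags such as --model_name_or_path depend on the
# repo's setup):
#   accelerate launch <path/to/this_script.py> \
#       --eval_data data/queries.jsonl \
#       --model_name_or_path meta-llama/Llama-2-7b-chat-hf \
#       --save_name llama2-7b-chat \
#       --lm_batch_size 4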