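"""
Score every candidate key (pos/neg passages or chat history turns) of each query with a
language model, then collate the per-key log-likelihoods into a `teacher_scores` column
saved alongside the original eval data. (This module docstring is a summary added for
readability; it is derived from the code below.)
"""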
import os
import json
import logging
import random
import datasets
from tqdm import tqdm
from datetime import timedelta
from accelerate import Accelerator, InitProcessGroupKwargs
from torch.utils.data import DataLoader
from dataclasses import dataclass, field
from collections import defaultdict
from transformers import HfArgumentParser

from src.lm import LM, LMArgs
from src.utils.util import split_file_dir_name_ext, makedirs, save_pickle, load_pickle, remove_eos, DefaultDataCollator, DatasetProcessFn

logger = logging.getLogger(__name__)

@dataclass
class ScoreArgs(LMArgs):
    eval_data: str = field(
        default=None,
        metadata={'help': 'Path to the query jsonl file.'}
    )
    context_max_length: int = field(
        default=1024,
        metadata={'help': 'Max length for the lm context.'}
    )
    key_max_length: int = field(
        default=512,
        metadata={'help': 'Max length for each key.'}
    )
    lm_batch_size: int = field(
        default=4,
        metadata={'help': 'Batch size for lm scoring.'},
    )
    save_name: str = field(
        default="llama2-7b-chat",
        metadata={'help': 'Name of the scored file.'}
    )
    load_score: bool = field(
        default=False,
        metadata={'help': 'Load scores from the temporary pickle file?'}
    )
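
# Example invocation (a sketch, not from the original repo; the script name and data
# paths are assumptions, while the flags map to the ScoreArgs/LMArgs fields above):
#   accelerate launch score_lm.py \
#       --eval_data data/train.jsonl \
#       --model_name_or_path meta-llama/Llama-2-7b-chat-hf \
#       --save_name llama2-7b-chat \
#       --lm_batch_size 4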

def process_lm_scoring(tokenizer, key_max_length=512):
    # Probe the tokenizer with a dummy string to check whether it automatically
    # prepends a BOS token and/or appends an EOS token.
    test = tokenizer("test", return_special_tokens_mask=True)["special_tokens_mask"]
    has_bos = has_eos = False
    if test[0] == 1:
        has_bos = True
    if test[-1] == 1:
        has_eos = True
    @DatasetProcessFn(augment=True)
    def _process(query, answers, query_id, task, pos=None, neg=None, history=None, context_inputs=None, query_inputs=None, answer_inputs=None, score_inputs=None, _index=None, **kwds):
        """Yield one scoring example per key (pos & neg, or history)."""
        if task in ["qa", "convsearch"]:
            template = "Knowledge: {key.strip()}\n\nQuestion: {query.strip()}\n\nAnswer: {answer.strip()}"
        elif task == "icl":
            template = "{key}\n{query}\n{answer}"
        elif task == "lrlm":
            # template = "{key}{continuation[i]}{context}{query}{answer}"
            pass
        elif task == "chat":
            template = "{key}\nSpeaker 1: {query}\nSpeaker 2: {answer}"
        else:
            raise NotImplementedError(f"Task type {task} not implemented!")

        output = defaultdict(list)

        # NOTE: sample 1 answer for scoring if there are multiple
        if len(answers) > 1:
            answer = random.choice(answers)
        else:
            answer = answers[0]

        if history is not None:
            assert task == "chat", f"Found history={history} is not None but task={task} is not 'chat'!"
            keys = history
        else:
            keys = pos + neg

        for i, key in enumerate(keys):
            # NOTE: do not add special tokens!
            if task == "lrlm":
                # inputs are already tokenized: concatenate key, context, query and answer ids
                score_input = score_inputs[i]
                input_ids = score_input + context_inputs + query_inputs + answer_inputs
                attention_mask = [1 for _ in input_ids]
                inputs = {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask
                }
                labels = input_ids.copy()
                answer_length = len(answer_inputs)
                # only the answer tokens contribute to the loss
                labels[:-answer_length] = [-100] * (len(labels) - answer_length)
                inputs["labels"] = labels
            else:
                # truncate key
                key = tokenizer.decode(tokenizer.encode(key, add_special_tokens=False, max_length=key_max_length, truncation=True))
                # render the task template as an f-string with key/query/answer in scope
                seq = eval(f"f{repr(template)}")
                inputs = tokenizer(seq, return_token_type_ids=False)
                if has_eos:
                    inputs = remove_eos(inputs, tokenizer.eos_token_id)
                # find answer length
                answer_seq = tokenizer.encode("Answer: " + answer.lstrip(" "), add_special_tokens=False)
                answer_length = len(answer_seq) - len(tokenizer.encode("Answer: ", add_special_tokens=False))
                assert answer_length > 0, f"No answer found in inputs {_index}!"
                # take care of padded tokens
                labels = inputs["input_ids"].copy()
                labels = [x if inputs["attention_mask"][i] == 1 else -100 for i, x in enumerate(labels)]
                labels[:-answer_length] = [-100] * (len(labels) - answer_length)
                inputs["labels"] = labels

            for k, v in inputs.items():
                output[k].append(v)
            output["query_id"].append(query_id)
        return output
    return _process

def collate_scores(eval_data, save_name):
    """
    Collate the lm scores based on query_ids.
    Append a 'teacher_scores' column to every sample of eval_data and save the result
    as {eval_data_name}.scored.{save_name}{ext} next to the original file.
    """
    def collate(query_ids, scores):
        # only on main process
        eval_data_folder, eval_data_name, eval_data_ext = split_file_dir_name_ext(eval_data)
        data_save_path = os.path.join(eval_data_folder, f"{eval_data_name}.scored.{save_name}" + eval_data_ext)
        makedirs(data_save_path)

        prev_query_id = None
        teacher_scores = []
        try:
            logger.info(f"saving data to {data_save_path}...")
            with open(eval_data) as f, open(data_save_path, "w") as g:
                for query_id, score in tqdm(zip(query_ids, scores)):
                    if (query_id != prev_query_id) and (prev_query_id is not None):
                        sample = json.loads(f.readline().strip())
                        assert prev_query_id == sample["query_id"], f"Found incompatible query_id from data ({sample['query_id']}) and from eval_preds ({prev_query_id})"
                        if "history" in sample:
                            assert len(sample["history"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['history'])}) and from eval_preds ({len(teacher_scores)})"
                        else:
                            assert len(sample["pos"] + sample["neg"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['pos'] + sample['neg'])}) and from eval_preds ({len(teacher_scores)})"
                        sample["teacher_scores"] = teacher_scores.copy()
                        if sample["task"] == "lrlm" and "query_inputs" in sample:
                            del sample["query_inputs"]
                            del sample["answer_inputs"]
                            del sample["context_inputs"]
                            del sample["score_inputs"]
                        g.write(json.dumps(sample, ensure_ascii=False) + "\n")
                        teacher_scores.clear()

                    # accumulate scores of different keys for the same query
                    # log likelihood
                    teacher_scores.append(-score)
                    prev_query_id = query_id
                # NOTE: the last line
                sample = json.loads(f.readline().strip())
                assert prev_query_id == sample["query_id"], f"Found incompatible query_id from data ({sample['query_id']}) and from eval_preds ({prev_query_id})"
                if "history" in sample:
                    assert len(sample["history"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['history'])}) and from eval_preds ({len(teacher_scores)})"
                else:
                    assert len(sample["pos"] + sample["neg"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['pos'] + sample['neg'])}) and from eval_preds ({len(teacher_scores)})"
                sample["teacher_scores"] = teacher_scores.copy()
                if sample["task"] == "lrlm" and "query_inputs" in sample:
                    del sample["query_inputs"]
                    del sample["answer_inputs"]
                    del sample["context_inputs"]
                    del sample["score_inputs"]
                g.write(json.dumps(sample, ensure_ascii=False) + "\n")
                teacher_scores.clear()
        except:
            save_path = os.path.join(eval_data_folder, f"{eval_data_name}.{save_name}.pkl")
            logger.error(f"Error when trying to save to json file. Save scores to {save_path} instead!")
            save_pickle((query_ids, scores), save_path)
            raise
    return collate

def main():
    parser = HfArgumentParser([ScoreArgs])
    args, = parser.parse_args_into_dataclasses()
    args: ScoreArgs

    accelerator = Accelerator(cpu=args.cpu, kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=100000))])

    logger.info(f"Loading model from {args.model_name_or_path}...")
    llm = LM(
        model_name_or_path=args.model_name_or_path,
        dtype=args.lm_dtype,
        padding_side=args.padding_side,
        cache_dir=args.model_cache_dir,
        accelerator=accelerator
    )
    llm.to(accelerator.device)
    tokenizer = llm.tokenizer

    logger.info(f"Loading data from {args.eval_data}...")
    if args.load_score:
        eval_data_folder, eval_data_name, eval_data_ext = split_file_dir_name_ext(args.eval_data)
        save_path = os.path.join(eval_data_folder, f"{eval_data_name}.{args.save_name}.pkl")
        # the pickle stores the (query_ids, scores) tuple written by collate_scores on failure
        query_ids, scores = load_pickle(save_path)
    else:
        with accelerator.main_process_first():
            # dataset = datasets.load_dataset("json", data_files=args.eval_data, split="train[:100]", cache_dir=args.dataset_cache_dir)
            dataset = datasets.load_dataset("json", data_files=args.eval_data, split="train", cache_dir=args.dataset_cache_dir)
            dataset = dataset.map(
                process_lm_scoring(tokenizer=tokenizer, key_max_length=args.key_max_length),
                remove_columns=dataset.column_names,
                batched=True,
                num_proc=32,
                with_indices=True
            )

        data_collator = DefaultDataCollator(tokenizer=tokenizer, add_position_ids=args.add_position_ids)
        dataloader = DataLoader(
            dataset,
            batch_size=args.lm_batch_size,
            collate_fn=data_collator,
            pin_memory=True,
        )
        dataloader = accelerator.prepare(dataloader)

        query_ids, scores = llm.compute_nlls(dataloader)
    if accelerator.process_index == 0:
        collate_scores(args.eval_data, args.save_name)(query_ids, scores)


if __name__ == "__main__":
    main()