2022-01-02 20:12:34 -05:00
from itertools import chain
2021-11-23 14:26:39 -05:00
from typing import Dict , Any
2022-03-20 22:03:02 -04:00
import numpy as np
2022-01-02 20:12:34 -05:00
from . . data import (
SUMMARIZATION ,
SEQREGRESSION ,
SEQCLASSIFICATION ,
MULTICHOICECLASSIFICATION ,
2022-01-03 13:44:10 -05:00
TOKENCLASSIFICATION ,
NLG_TASKS ,
2022-01-02 20:12:34 -05:00
)
2021-12-10 12:32:49 -05:00
2022-07-05 13:38:21 -04:00
import pandas as pd
2021-11-16 14:06:20 -05:00
2021-11-23 14:26:39 -05:00
def load_default_huggingface_metric_for_task(task):
    """Return the name of the default HuggingFace metric for the given task."""
    default_metric_per_task = {
        SEQCLASSIFICATION: "accuracy",
        SEQREGRESSION: "r2",
        SUMMARIZATION: "rouge1",
        MULTICHOICECLASSIFICATION: "accuracy",
        TOKENCLASSIFICATION: "seqeval",
    }
    # Unknown tasks yield None, matching the implicit fall-through of the
    # original if/elif chain.
    return default_metric_per_task.get(task)
2021-11-16 14:06:20 -05:00
2022-03-20 22:03:02 -04:00
def tokenize_text(X, Y=None, task=None, hf_args=None, tokenizer=None):
    """Dispatch tokenization of (X, Y) to the task-specific helper.

    Returns a (tokenized_X, tokenized_Y_or_None) pair for every task type.
    """
    if task in (SEQCLASSIFICATION, SEQREGRESSION):
        # Labels are not tokenized for plain classification/regression.
        tokenized = tokenize_onedataframe(
            X,
            tokenizer=tokenizer,
            task=task,
            hf_args=hf_args,
            prefix_str="",
        )
        return tokenized, None
    if task == TOKENCLASSIFICATION:
        return tokenize_text_tokclassification(
            X, Y, tokenizer=tokenizer, hf_args=hf_args
        )
    if task in NLG_TASKS:
        return tokenize_seq2seq(X, Y, tokenizer=tokenizer, task=task, hf_args=hf_args)
    if task == MULTICHOICECLASSIFICATION:
        return tokenize_text_multiplechoice(X, tokenizer=tokenizer, hf_args=hf_args)
2021-12-20 17:19:32 -05:00
2022-03-20 22:03:02 -04:00
def tokenize_seq2seq(X, Y, tokenizer, task=None, hf_args=None):
    """Tokenize inputs (and, if given, targets) for a seq2seq (NLG) task."""
    model_inputs = tokenize_onedataframe(
        X,
        tokenizer=tokenizer,
        task=task,
        hf_args=hf_args,
        prefix_str="summarize: ",
    )
    model_outputs = None
    if Y is not None:
        model_outputs = tokenize_onedataframe(
            Y.to_frame(),
            tokenizer=tokenizer,
            task=task,
            hf_args=hf_args,
            prefix_str="",
        )
        # Mask pad positions with -100 so the seq2seq loss ignores them.
        pad_id = tokenizer.pad_token_id
        model_outputs["label"] = [
            [-100 if token_id == pad_id else token_id for token_id in ids]
            for ids in model_outputs["input_ids"]
        ]
        # Only the masked "label" column is needed downstream.
        model_outputs = model_outputs.drop(
            columns=["attention_mask", "input_ids", "decoder_input_ids"]
        )
    return model_inputs, model_outputs
2021-12-20 17:19:32 -05:00
2022-01-03 13:44:10 -05:00
def tokenize_and_align_labels(
    examples,
    tokenizer,
    label_to_id,
    b_to_i_label,
    hf_args=None,
    X_sent_key=None,
    Y_sent_key=None,
    return_column_name=False,
):
    """Tokenize one example (a list of words) and align per-word labels to subtokens.

    Args:
        examples: Row/mapping holding the word list under X_sent_key and,
            optionally, the per-word labels under Y_sent_key.
        tokenizer: HuggingFace-style tokenizer, called with is_split_into_words=True.
        label_to_id: Mapping from raw label to integer id.
        b_to_i_label: For each label id, the id of its I- counterpart (or itself).
        hf_args: Optional args object; pad_to_max_length, max_seq_length and
            label_all_tokens are read from it when present.
        X_sent_key: Key of the input word list.
        Y_sent_key: Key of the labels; None means tokenize inputs only.
        return_column_name: If True, also return the sorted output column names.

    Returns:
        A list of tokenized values (one per output column, columns sorted by
        name); with return_column_name=True, a (values, column_names) tuple.
    """
    tokenized_inputs = tokenizer(
        [list(examples[X_sent_key])],
        padding="max_length"
        if hf_args and hf_args.pad_to_max_length
        else False,  # to be consistent with https://github.com/huggingface/transformers/blob/main/examples/pytorch/token-classification/run_ner.py#L394
        truncation=True,
        max_length=hf_args.max_seq_length if hf_args else None,
        # We use this argument because the texts in our dataset are lists of words (with a label for each word).
        is_split_into_words=True,
    )
    if Y_sent_key is not None:
        previous_word_idx = None
        label_ids = []
        for word_idx in tokenized_inputs.word_ids(batch_index=0):
            if word_idx is None:
                # Special tokens (CLS/SEP/pad) get the ignore index.
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                # First subtoken of a word carries the word's label.
                label_ids.append(label_to_id[examples[Y_sent_key][word_idx]])
            else:
                # Use label_all_tokens to control whether to copy the label to all
                # subtokens or to pad the additional subtokens with -100.
                # BUGFIX: guard hf_args for None here, consistent with the
                # `hf_args and hf_args.pad_to_max_length` checks above;
                # previously this raised AttributeError when hf_args was None.
                if hf_args and hf_args.label_all_tokens:
                    # If a B- word is split into multiple subtokens, map the
                    # additional subtokens to the matching I- label.
                    label_ids.append(
                        b_to_i_label[label_to_id[examples[Y_sent_key][word_idx]]]
                    )
                else:
                    label_ids.append(-100)
            previous_word_idx = word_idx
        tokenized_inputs["labels"] = label_ids
    tmp_column_names = sorted(tokenized_inputs.keys())
    tokenized_input_and_labels = [tokenized_inputs[x] for x in tmp_column_names]
    for key_idx, each_key in enumerate(tmp_column_names):
        # The tokenizer returns batch-of-one lists; unwrap every column
        # except "labels", which is already per-token.
        if each_key != "labels":
            tokenized_input_and_labels[key_idx] = tokenized_input_and_labels[key_idx][0]
    if return_column_name:
        return tokenized_input_and_labels, tmp_column_names
    else:
        return tokenized_input_and_labels
2022-01-03 13:44:10 -05:00
2022-03-20 22:03:02 -04:00
def tokenize_text_tokclassification(X, Y, tokenizer, hf_args=None):
    """Tokenize a token-classification dataset row by row.

    Args:
        X: DataFrame with a single column of word lists.
        Y: Optional Series of per-word label lists, aligned with X.
        tokenizer: HuggingFace-style tokenizer passed through to
            tokenize_and_align_labels.
        hf_args: Args object; label_list is read from it (required), other
            fields are forwarded.

    Returns:
        (X_tokenized, y_tokenized): tokenized features as a DataFrame and, when
        Y was given, the aligned label column as a Series (else None).
    """
    # If the label_all_tokens flag is True, prepare two dicts label_to_id and b_to_i_label to convert the B- labels to I- labels
    label_to_id = {i: i for i in range(len(hf_args.label_list))}
    b_to_i_label = []
    for idx, label in enumerate(hf_args.label_list):
        # Map each B- label id to the id of its I- counterpart when one exists;
        # otherwise the id maps to itself.
        if label.startswith("B-") and label.replace("B-", "I-") in hf_args.label_list:
            b_to_i_label.append(hf_args.label_list.index(label.replace("B-", "I-")))
        else:
            b_to_i_label.append(idx)
    if Y is not None:
        # Tokenize features and labels together so they stay row-aligned.
        X_and_Y = pd.concat([X, Y.to_frame()], axis=1)
        X_key = list(X.keys())[0]
        Y_key = list(Y.to_frame().keys())[0]
        # Probe the first row once to learn the tokenizer's output column names.
        _, tokenized_column_names = tokenize_and_align_labels(
            X_and_Y.iloc[0],
            tokenizer=tokenizer,
            hf_args=hf_args,
            X_sent_key=X_key,
            Y_sent_key=Y_key,
            return_column_name=True,
            label_to_id=label_to_id,
            b_to_i_label=b_to_i_label,
        )
        X_and_Y_tokenized = X_and_Y.apply(
            lambda x: tokenize_and_align_labels(
                x,
                tokenizer=tokenizer,
                hf_args=hf_args,
                X_sent_key=X_key,
                Y_sent_key=Y_key,
                label_to_id=label_to_id,
                b_to_i_label=b_to_i_label,
            ),
            axis=1,
            result_type="expand",
        )
        # Split the expanded columns back into features vs. the "labels" column.
        label_idx = tokenized_column_names.index("labels")
        other_indices = sorted(
            set(range(len(tokenized_column_names))).difference({label_idx})
        )
        other_column_names = [tokenized_column_names[x] for x in other_indices]
        d = X_and_Y_tokenized.iloc[:, other_indices]
        y_tokenized = X_and_Y_tokenized.iloc[:, label_idx]
    else:
        # No labels: tokenize the features only.
        X_key = list(X.keys())[0]
        _, tokenized_column_names = tokenize_and_align_labels(
            X.iloc[0],
            tokenizer=tokenizer,
            hf_args=hf_args,
            X_sent_key=X_key,
            Y_sent_key=None,
            return_column_name=True,
            label_to_id=label_to_id,
            b_to_i_label=b_to_i_label,
        )
        d = X.apply(
            lambda x: tokenize_and_align_labels(
                x,
                tokenizer=tokenizer,
                hf_args=hf_args,
                X_sent_key=X_key,
                Y_sent_key=None,
                label_to_id=label_to_id,
                b_to_i_label=b_to_i_label,
            ),
            axis=1,
            result_type="expand",
        )
        other_column_names = tokenized_column_names
        y_tokenized = None
    X_tokenized = pd.DataFrame(columns=other_column_names)
    X_tokenized[other_column_names] = d
    return X_tokenized, y_tokenized
2021-12-20 17:19:32 -05:00
def tokenize_onedataframe(
    X,
    tokenizer,
    task=None,
    hf_args=None,
    prefix_str=None,
):
    """Tokenize every row of DataFrame X into a new DataFrame of tokenizer outputs.

    Args:
        X: DataFrame of raw text columns.
        tokenizer: HuggingFace-style tokenizer (supports as_target_tokenizer()).
        task: Task constant; for SUMMARIZATION, prefix_str is prepended to the text.
        hf_args: Optional args forwarded to tokenize_row.
        prefix_str: Prefix (e.g. "summarize: ") used for summarization inputs.

    Returns:
        DataFrame with one column per tokenizer output field.
    """
    with tokenizer.as_target_tokenizer():
        # Probe the first row once to learn the tokenizer's output column names.
        _, tokenized_column_names = tokenize_row(
            dict(X.iloc[0]),
            tokenizer,
            # BUGFIX: compare the task constant by equality instead of identity;
            # `task is SUMMARIZATION` only worked by accident of object interning.
            prefix=(prefix_str,) if task == SUMMARIZATION else None,
            task=task,
            hf_args=hf_args,
            return_column_name=True,
        )
        d = X.apply(
            lambda x: tokenize_row(
                x,
                tokenizer,
                prefix=(prefix_str,) if task == SUMMARIZATION else None,
                task=task,
                hf_args=hf_args,
            ),
            axis=1,
            result_type="expand",
        )
        X_tokenized = pd.DataFrame(columns=tokenized_column_names)
        X_tokenized[tokenized_column_names] = d
        return X_tokenized
2021-12-20 17:19:32 -05:00
2022-03-20 22:03:02 -04:00
def tokenize_row(
    this_row,
    tokenizer,
    prefix=None,
    task=None,
    hf_args=None,
    return_column_name=False,
):
    """Tokenize a single row of raw text columns.

    Returns the tokenizer outputs as a list ordered by sorted column name;
    with return_column_name=True, also returns the column-name list.
    """
    if prefix:
        # Prepend the prefix (e.g. "summarize: ") to the matching column(s).
        this_row = tuple("".join(pair) for pair in zip(prefix, this_row))
    # tokenizer.pad_token = tokenizer.eos_token
    tokenized_example = tokenizer(
        *tuple(this_row),
        padding="max_length",
        max_length=hf_args.max_seq_length if hf_args else None,
        truncation=True,
    )
    # Seq2seq models also need decoder inputs; mirror the encoder ids.
    if task in NLG_TASKS:
        tokenized_example["decoder_input_ids"] = tokenized_example["input_ids"]
    column_names = sorted(tokenized_example.keys())
    values = [tokenized_example[name] for name in column_names]
    return (values, column_names) if return_column_name else values
2021-11-16 14:06:20 -05:00
2022-03-20 22:03:02 -04:00
def tokenize_text_multiplechoice(X, tokenizer, hf_args=None):
    """Tokenize SWAG-style multiple-choice rows and keep the raw columns alongside."""
    swag_columns = ["sent1", "sent2", "ending0", "ending1", "ending2", "ending3"]
    subset = X[swag_columns]
    # Probe the first row once to learn the tokenizer's output column names.
    _, tokenized_column_names = tokenize_swag(
        subset.iloc[0],
        tokenizer=tokenizer,
        hf_args=hf_args,
        return_column_name=True,
    )
    expanded = subset.apply(
        lambda row: tokenize_swag(row, tokenizer=tokenizer, hf_args=hf_args),
        axis=1,
        result_type="expand",
    )
    X_tokenized = pd.DataFrame(columns=tokenized_column_names)
    X_tokenized[tokenized_column_names] = expanded
    # Join the original text columns back so downstream code can see both.
    return X_tokenized.join(X), None
2022-03-20 22:03:02 -04:00
def tokenize_swag(this_row, tokenizer, hf_args=None, return_column_name=False):
    """Tokenize one SWAG-style multiple-choice example.

    The context sentence (sent1) is paired with each of the four candidate
    continuations formed by joining the question header (sent2) with each
    ending column.
    """
    # Repeat the context once per candidate ending.
    context_sentences = [this_row["sent1"]] * 4
    header = this_row["sent2"]
    # Each candidate = the shared header + one of the four endings.
    candidate_sentences = [
        header + " " + this_row[f"ending{i}"] for i in range(4)
    ]
    tokenized_example = tokenizer(
        context_sentences,
        candidate_sentences,
        truncation=True,
        max_length=hf_args.max_seq_length if hf_args else None,
        padding=False,
    )
    column_names = sorted(tokenized_example.keys())
    values = [tokenized_example[name] for name in column_names]
    return (values, column_names) if return_column_name else values
2022-01-02 20:12:34 -05:00
2022-01-03 13:44:10 -05:00
def is_a_list_of_str(this_obj):
    """Return True if this_obj is a list or numpy array whose elements are all str.

    An empty list/array returns True (vacuous truth), matching all() semantics.
    """
    # Idiom: a single isinstance call with a tuple replaces the or-chain.
    return isinstance(this_obj, (list, np.ndarray)) and all(
        isinstance(x, str) for x in this_obj
    )
2022-01-03 13:44:10 -05:00
2021-11-23 14:26:39 -05:00
def _clean_value ( value : Any ) - > str :
if isinstance ( value , float ) :
return " {:.5} " . format ( value )
else :
return str ( value ) . replace ( " / " , " _ " )
def format_vars(resolved_vars: Dict) -> str:
    """Format the resolved variable dict into a single experiment-tag string."""
    parts = []
    for path, value in sorted(resolved_vars.items()):
        # TrialRunner already has these in the experiment_tag.
        if path[0] in ["run", "env", "resources_per_trial"]:
            continue
        # Walk the path from the end: keep every int segment, plus only the
        # last (deepest) string segment.
        kept = []
        seen_string = False
        for segment in reversed(path):
            if isinstance(segment, int):
                kept.append(str(segment))
            elif not seen_string:
                seen_string = True
                kept.append(segment)
        kept.reverse()
        parts.append(_clean_value("_".join(kept)) + "=" + _clean_value(value))
    return ",".join(parts)
# Module-level counter. NOTE(review): appears unused within this file — the
# Counter class below keeps its own class attribute; verify before removing.
counter = 0
def date_str():
    """Return the current local time formatted as YYYY-MM-DD_HH-MM-SS."""
    from datetime import datetime

    now = datetime.today()
    return now.strftime("%Y-%m-%d_%H-%M-%S")
def _generate_dirname(experiment_tag, trial_id):
    """Build a unique trial directory name from the id, tag, and a timestamp."""
    # Cap the id+tag prefix at 130 chars so the final name stays filesystem-safe.
    prefix = f"train_{str(trial_id)}_{experiment_tag}"[:130]
    dirname = f"{prefix}_{date_str()}"
    return dirname.replace("/", "_")
def get_logdir_name(dirname, local_dir):
    """Return the full log-directory path for *dirname* under *local_dir*."""
    import os

    base = os.path.expanduser(local_dir)
    return os.path.join(base, dirname)
2022-03-20 22:03:02 -04:00
class Counter:
    """Monotonic counter used to generate unique trial folder names."""

    counter = 0  # shared across all calls; incremented once per trial

    @staticmethod
    def get_trial_fold_name(local_dir, trial_config, trial_id):
        """Return a unique log directory path for one trial."""
        Counter.counter += 1
        experiment_tag = "{0}_{1}".format(
            str(Counter.counter), format_vars(trial_config)
        )
        return get_logdir_name(
            _generate_dirname(experiment_tag, trial_id=trial_id), local_dir
        )
2021-11-23 14:26:39 -05:00
2022-04-28 14:06:29 -04:00
def load_model(checkpoint_path, task, num_labels=None):
    """Load a transformers model for *task* from *checkpoint_path*.

    Args:
        checkpoint_path: Path or hub id of the pretrained checkpoint.
        task: Task constant (SEQCLASSIFICATION, SEQREGRESSION,
            TOKENCLASSIFICATION, an NLG task, or MULTICHOICECLASSIFICATION).
        num_labels: Number of output labels for classification tasks.

    Returns:
        The instantiated transformers model, with its classification head and
        token embeddings adjusted for SEQCLASSIFICATION when needed.
    """
    import transformers

    # Silence transformers' info/warning logging during model loading.
    transformers.logging.set_verbosity_error()

    from transformers import AutoConfig
    from .huggingface.switch_head_auto import (
        AutoSeqClassificationHead,
        MODEL_CLASSIFICATION_HEAD_MAPPING,
    )
    from ..data import SEQCLASSIFICATION, SEQREGRESSION, TOKENCLASSIFICATION

    def get_this_model(checkpoint_path, task, model_config):
        # Select the AutoModel class that matches the task family.
        from transformers import AutoModelForSequenceClassification
        from transformers import AutoModelForSeq2SeqLM
        from transformers import AutoModelForMultipleChoice
        from transformers import AutoModelForTokenClassification

        if task in (SEQCLASSIFICATION, SEQREGRESSION):
            return AutoModelForSequenceClassification.from_pretrained(
                checkpoint_path, config=model_config
            )
        elif task == TOKENCLASSIFICATION:
            return AutoModelForTokenClassification.from_pretrained(
                checkpoint_path, config=model_config
            )
        elif task in NLG_TASKS:
            return AutoModelForSeq2SeqLM.from_pretrained(
                checkpoint_path, config=model_config
            )
        elif task == MULTICHOICECLASSIFICATION:
            return AutoModelForMultipleChoice.from_pretrained(
                checkpoint_path, config=model_config
            )

    def is_pretrained_model_in_classification_head_list(model_type):
        # True when a custom classification head is registered for this model type.
        return model_type in MODEL_CLASSIFICATION_HEAD_MAPPING

    def _set_model_config(checkpoint_path):
        # NOTE: reads model_config_num_labels from the enclosing scope; callers
        # must set it before calling this for classification-style tasks. For
        # other tasks (the else branch) it is not read.
        if task in (SEQCLASSIFICATION, SEQREGRESSION, TOKENCLASSIFICATION):
            model_config = AutoConfig.from_pretrained(
                checkpoint_path,
                num_labels=model_config_num_labels,
            )
            return model_config
        else:
            model_config = AutoConfig.from_pretrained(checkpoint_path)
            return model_config

    current_config = AutoConfig.from_pretrained(checkpoint_path)
    this_model_type, this_vocab_size = (
        current_config.model_type,
        current_config.vocab_size,
    )

    if task == SEQCLASSIFICATION:
        num_labels_old = current_config.num_labels
        # Keep the checkpoint's label count when a custom head will be swapped
        # in below; otherwise load directly with the requested num_labels.
        if is_pretrained_model_in_classification_head_list(this_model_type):
            model_config_num_labels = num_labels_old
        else:
            model_config_num_labels = num_labels
        new_config = _set_model_config(checkpoint_path)
        if is_pretrained_model_in_classification_head_list(this_model_type):
            if num_labels != num_labels_old:
                # Load with the old head, then replace the classifier with a
                # freshly sized classification head for the new label count.
                this_model = get_this_model(checkpoint_path, task, new_config)
                new_config.num_labels = num_labels
                this_model.num_labels = num_labels
                this_model.classifier = (
                    AutoSeqClassificationHead.from_model_type_and_config(
                        this_model_type, new_config
                    )
                )
            else:
                this_model = get_this_model(checkpoint_path, task, new_config)
        else:
            this_model = get_this_model(checkpoint_path, task, new_config)
        # Keep the embedding matrix consistent with the checkpoint's vocab size.
        this_model.resize_token_embeddings(this_vocab_size)
        return this_model
    else:
        if task == SEQREGRESSION:
            # Regression is modeled as single-output classification.
            model_config_num_labels = 1
        elif task == TOKENCLASSIFICATION:
            model_config_num_labels = num_labels
        model_config = _set_model_config(checkpoint_path)
        this_model = get_this_model(checkpoint_path, task, model_config)
        return this_model
2022-07-05 13:38:21 -04:00
def postprocess_prediction_and_true(
    task, y_pred, tokenizer, hf_args, y_true=None, X=None
):
    """Convert raw model outputs (and optionally ground truth) to readable form.

    Args:
        task: Task constant selecting the postprocessing branch.
        y_pred: Raw prediction matrix/logits from the model.
        tokenizer: Tokenizer used for decoding (summarization) or re-tokenizing (NER).
        hf_args: Args object; label_list is read for token classification.
        y_true: Optional ground truth in model format.
        X: Optional raw inputs, used to infer padding when y_true is None.

    Returns:
        (postprocessed_predictions, postprocessed_ground_truth_or_None).
    """
    # postprocess the matrix prediction y_pred and ground truth y_true into user readable format, e.g., for summarization, decode into text
    if task == SEQCLASSIFICATION:
        return np.argmax(y_pred, axis=1), y_true
    elif task == SEQREGRESSION:
        return np.squeeze(y_pred), y_true  # predictions.reshape((len(predictions),))
    elif task == TOKENCLASSIFICATION:
        assert (y_true is not None) or (
            X is not None
        ), "One of y_true and X must not be None"
        ## If y_true is not None, we use y_true to remove the -100 in the prediction (postprocessing), and return the postprocessed y_true and prediction
        # If y_true is None, we use X to compute y_is_pad (i.e., whether y_true is -100 in that position), and use y_is_pad to remove the -100 in the prediction, and return the postprocessed prediction (not the y_true)
        y_predict = pd.Series(np.argmax(y_pred, axis=2).tolist())
        if y_true is None:
            # Re-tokenize X with the predictions as stand-in labels just to
            # recover which positions are padding (-100).
            _, y_is_pad = tokenize_text(
                X,
                y_predict,
                task=task,
                hf_args=hf_args,
                tokenizer=tokenizer,
            )
        else:
            y_is_pad = y_true
        label_len = len(hf_args.label_list)
        # Drop every (prediction, truth) pair at a padding position.
        zip_pred_ispad = [
            [(p, ispd) for (p, ispd) in zip(each_pred, each_is_pad) if ispd != -100]
            for (each_pred, each_is_pad) in zip(y_predict, y_is_pad)
        ]
        # Out-of-range predicted ids map to -1 rather than raising.
        y_pred_label = [
            [
                hf_args.label_list[p] if 0 <= p < label_len else -1
                for (p, ispd) in each_list
            ]
            for each_list in zip_pred_ispad
        ]  # To compute precision and recall, y_pred and y_true must be converted to string labels
        # (B-PER, I-PER, etc.), so that the category-based precision/recall (i.e., PER, LOC, etc.) scores can be computed
        if y_true is not None:
            y_true_label = [
                [tr for (p, tr) in each_list] for each_list in zip_pred_ispad
            ]
        else:
            y_true_label = None
        return y_pred_label, y_true_label
    elif task == SUMMARIZATION:
        if isinstance(y_pred, tuple):
            # Some models return (logits, ...); take the argmax over the vocab.
            y_pred = np.argmax(y_pred[0], axis=2)
        decoded_preds = tokenizer.batch_decode(y_pred, skip_special_tokens=True)
        import nltk

        # Sentence tokenizer data is required by nltk.sent_tokenize below.
        nltk.download("punkt")
        decoded_preds = [pred.strip() for pred in decoded_preds]
        # rougeLSum expects newline-separated sentences.
        decoded_preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in decoded_preds]
        if y_true is not None:
            # Restore pad ids in place of the -100 ignore index before decoding.
            y_true_labels = np.where(y_true != -100, y_true, tokenizer.pad_token_id)
            decoded_y_true_labels = tokenizer.batch_decode(
                y_true_labels, skip_special_tokens=True
            )
            decoded_y_true_labels = [label.strip() for label in decoded_y_true_labels]
            decoded_y_true_labels = [
                "\n".join(nltk.sent_tokenize(label)) for label in decoded_y_true_labels
            ]
        else:
            decoded_y_true_labels = None
        return decoded_preds, decoded_y_true_labels
    elif task == MULTICHOICECLASSIFICATION:
        return np.argmax(y_pred, axis=1), y_true
class LabelEncoderforTokenClassification:
    """Encode per-token string labels into integer ids for token classification."""

    def fit_transform(self, y):
        """Fit on y and return the encoded labels.

        Args:
            y: pandas Series of per-sentence label lists. If the labels are
                strings, a vocabulary is built and they are mapped to ids;
                if they are already ints they are returned unchanged.

        Returns:
            The Series with every label list converted to a list of ids.
        """
        # if the labels are tokens, convert them to ids
        if any(isinstance(label, str) for label in y[0]):
            # Sorted union of all labels gives a deterministic id assignment.
            self.label_list = sorted(list(set().union(*y)))
            # Fixed: avoid shadowing the builtin `id` in the mapping construction.
            self._tokenlabel_to_id = {
                label: idx for idx, label in enumerate(self.label_list)
            }
            y = y.apply(
                lambda sent: [self._tokenlabel_to_id[token] for token in sent]
            )
        # if the labels are not tokens, they must be ids
        else:
            assert all(
                isinstance(label, int) for label in y[0]
            ), "The labels must either be tokens or ids"
        return y

    def transform(self, y):
        """Encode y with the mapping learned by fit_transform (no-op if unfitted)."""
        if hasattr(self, "_tokenlabel_to_id"):
            y = y.apply(
                lambda sent: [self._tokenlabel_to_id[token] for token in sent]
            )
        return y