From 2a8decdc50bb5da054c20fb91760d30e445a77a0 Mon Sep 17 00:00:00 2001 From: Xueqing Liu Date: Tue, 10 May 2022 17:22:57 -0400 Subject: [PATCH] fix the post-processing bug in NER (#534) * fix conll bug * update DataCollatorForAuto * adding label_list comments --- flaml/automl.py | 7 ++++- flaml/ml.py | 34 +++++++++++++++----- flaml/model.py | 21 +++++++------ flaml/nlp/huggingface/data_collator.py | 35 ++++++++++----------- flaml/nlp/huggingface/training_args.py | 35 +++++++-------------- flaml/nlp/utils.py | 10 +++--- test/automl/test_classification.py | 1 + test/nlp/test_autohf_tokenclassification.py | 13 +++++++- 8 files changed, 92 insertions(+), 64 deletions(-) diff --git a/flaml/automl.py b/flaml/automl.py index 90d9ebf94..cc455237b 100644 --- a/flaml/automl.py +++ b/flaml/automl.py @@ -1569,7 +1569,10 @@ class AutoML(BaseEstimator): ``` **fit_kwargs: Other key word arguments to pass to fit() function of - the searched learners, such as sample_weight. + the searched learners, such as sample_weight. Include: + period: int | forecast horizon for ts_forecast tasks. + gpu_per_trial: float, default = 0 | A float of the number of gpus per trial, + only used by TransformersEstimator and XGBoostSklearnEstimator. """ task = task or self._settings.get("task") eval_method = eval_method or self._settings.get("eval_method") @@ -2198,6 +2201,8 @@ class AutoML(BaseEstimator): } } fit_kwargs_by_estimator: dict, default=None | The user specified keywords arguments, grouped by estimator name. + For TransformersEstimator, available fit_kwargs can be found from + [flaml/nlp/training_args.py:TrainingArgumentsForAuto](https://microsoft.github.io/FLAML/docs/reference/nlp/huggingface/training_args). e.g., ```python diff --git a/flaml/ml.py b/flaml/ml.py index 51b6fb600..092a02565 100644 --- a/flaml/ml.py +++ b/flaml/ml.py @@ -165,7 +165,7 @@ def metric_loss_score( import datasets datasets_metric_name = huggingface_submetric_to_metric.get( - metric_name, metric_name + metric_name, metric_name.split(":")[0] ) metric = datasets.load_metric(datasets_metric_name) metric_mode = huggingface_metric_to_mode[datasets_metric_name] @@ -174,17 +174,30 @@ def metric_loss_score( score = metric.compute(predictions=y_predict, references=y_true)[ metric_name ].mid.fmeasure - elif metric_name == "seqeval": - y_true = [ - [x for x in each_y_true if x != -100] for each_y_true in y_true + elif metric_name.startswith("seqeval"): + + zip_pred_true = [ + [(p, lb) for (p, lb) in zip(prediction, label) if lb != -100] + for (prediction, label) in zip(y_predict, y_true) ] y_pred = [ - y_predict[each_idx][: len(y_true[each_idx])] - for each_idx in range(len(y_predict)) + [labels[p] for (p, l) in each_list] + for each_list in zip_pred_true + ] # To compute precision and recall, y_pred and y_true must be converted to string labels + # (B-PER, I-PER, etc.), so that the category-based precision/recall (i.e., PER, LOC, etc.) 
+        # scores can be computed
+        y_true = [
+            [labels[l] for (p, l) in each_list]
+            for each_list in zip_pred_true
         ]
+
+        metric_submetric_names = metric_name.split(":")
+
         score = metric.compute(predictions=y_pred, references=y_true)[
-            "overall_accuracy"
+            metric_submetric_names[1]
+            if len(metric_submetric_names) > 1
+            else "overall_accuracy"
         ]
+
     else:
         score = metric.compute(predictions=y_predict, references=y_true)[
             metric_name
@@ -454,7 +467,9 @@ def evaluate_model_CV(
     if task in CLASSIFICATION:
         labels = np.unique(y_train_all)
     else:
-        labels = None
+        labels = fit_kwargs.get(
+            "label_list"
+        )  # pass the label list on to compute the evaluation metric
     groups = None
     shuffle = False if task in TS_FORECAST else True
     if isinstance(kf, RepeatedStratifiedKFold):
@@ -586,6 +601,9 @@ def compute_estimator(
         groups_val,
         eval_metric,
         task,
+        labels=fit_kwargs.get(
+            "label_list"
+        ),  # pass the label list on to compute the evaluation metric
         budget=budget,
         log_training_metric=log_training_metric,
         fit_kwargs=fit_kwargs,
diff --git a/flaml/model.py b/flaml/model.py
index ec3ecad0f..0ecac416f 100644
--- a/flaml/model.py
+++ b/flaml/model.py
@@ -365,9 +365,9 @@ class TransformersEstimator(BaseEstimator):
         self._TrainingArguments = TrainingArguments
 
     @staticmethod
-    def _join(X_train, y_train):
+    def _join(X_train, y_train, task):
         y_train = DataFrame(y_train, index=X_train.index)
-        y_train.columns = ["label"]
+        y_train.columns = ["label"] if task != TOKENCLASSIFICATION else ["labels"]
         train_df = X_train.join(y_train)
         return train_df
@@ -380,7 +380,7 @@ class TransformersEstimator(BaseEstimator):
             },
             "num_train_epochs": {
                 "domain": tune.loguniform(lower=0.1, upper=10.0),
-                "init_value": 1,
+                "init_value": 3.0,  # to be consistent with roberta
             },
             "per_device_train_batch_size": {
                 "domain": tune.choice([4, 8, 16, 32]),
@@ -511,7 +511,7 @@ class TransformersEstimator(BaseEstimator):
         processed_X, processed_y = self._preprocess(X=X, y=y, **self._kwargs)
         processed_dataset = Dataset.from_pandas(
-            TransformersEstimator._join(processed_X, processed_y)
+            TransformersEstimator._join(processed_X, processed_y, self._task)
         )
 
         return processed_dataset, processed_X, processed_y
@@ -547,14 +547,14 @@ class TransformersEstimator(BaseEstimator):
     @property
     def data_collator(self):
-        from .nlp.huggingface.data_collator import DataCollatorForAuto
+        from .nlp.huggingface.data_collator import task_to_datacollator_class
 
         return (
-            DataCollatorForAuto(
+            task_to_datacollator_class[self._task](
                 tokenizer=self.tokenizer,
-                pad_to_multiple_of=8 if self._training_args.fp16 else None,
+                pad_to_multiple_of=8,  # if self._training_args.fp16 else None,
             )
-            if self._task == MULTICHOICECLASSIFICATION
+            if self._task in (MULTICHOICECLASSIFICATION, TOKENCLASSIFICATION)
             else None
         )
@@ -750,7 +750,10 @@ class TransformersEstimator(BaseEstimator):
             )
             metric_dict = {
                 "automl_metric": metric_loss_score(
-                    metric_name=self._metric, y_predict=predictions, y_true=labels
+                    metric_name=self._metric,
+                    y_predict=predictions,
+                    y_true=labels,
+                    labels=self._training_args.label_list,
                 )
             }
         else:
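The heart of the fix above is the label post-processing in `metric_loss_score`: seqeval scores entity categories (PER, LOC, etc.), so the integer label ids must be mapped back to their string tags via `label_list` and the `-100` padding positions dropped before calling `metric.compute`. The snippet below is a minimal standalone sketch of that logic; the helper name `seqeval_score` and the toy inputs are illustrative, not FLAML code.

```python
# Minimal sketch of the seqeval post-processing; requires the `seqeval` package.
# `label_list` maps integer ids to string tags; -100 marks ignored positions.
import datasets


def seqeval_score(y_predict, y_true, label_list, metric_name="seqeval:overall_f1"):
    metric = datasets.load_metric("seqeval")
    # Keep only positions that carry a real label, pairing predictions with references.
    zip_pred_true = [
        [(p, t) for (p, t) in zip(pred, true) if t != -100]
        for (pred, true) in zip(y_predict, y_true)
    ]
    y_pred_tags = [[label_list[p] for (p, _) in pairs] for pairs in zip_pred_true]
    y_true_tags = [[label_list[t] for (_, t) in pairs] for pairs in zip_pred_true]
    # "seqeval" alone falls back to overall_accuracy; "seqeval:<key>" selects a sub-metric.
    parts = metric_name.split(":")
    key = parts[1] if len(parts) > 1 else "overall_accuracy"
    return metric.compute(predictions=y_pred_tags, references=y_true_tags)[key]


label_list = ["O", "B-PER", "I-PER"]
print(seqeval_score([[1, 2, 0]], [[1, 2, -100]], label_list))  # 1.0: the PER entity is recovered
```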
diff --git a/flaml/nlp/huggingface/data_collator.py b/flaml/nlp/huggingface/data_collator.py
index bc8284c6a..1203a536c 100644
--- a/flaml/nlp/huggingface/data_collator.py
+++ b/flaml/nlp/huggingface/data_collator.py
@@ -1,9 +1,15 @@
 from dataclasses import dataclass
-from transformers.data.data_collator import DataCollatorWithPadding
+from transformers.data.data_collator import (
+    DataCollatorWithPadding,
+    DataCollatorForTokenClassification,
+)
+from collections import OrderedDict
+
+from flaml.data import TOKENCLASSIFICATION, MULTICHOICECLASSIFICATION
 
 
 @dataclass
-class DataCollatorForAuto(DataCollatorWithPadding):
+class DataCollatorForMultipleChoiceClassification(DataCollatorWithPadding):
     def __call__(self, features):
         from itertools import chain
         import torch
@@ -22,7 +28,9 @@ class DataCollatorForAuto(DataCollatorWithPadding):
             for feature in features
         ]
         flattened_features = list(chain(*flattened_features))
-        batch = super(DataCollatorForAuto, self).__call__(flattened_features)
+        batch = super(DataCollatorForMultipleChoiceClassification, self).__call__(
+            flattened_features
+        )
         # Un-flatten
         batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
         # Add back labels
@@ -31,18 +39,9 @@
         return batch
 
 
-class DataCollatorForPredict(DataCollatorWithPadding):
-    def __call__(self, features):
-        from itertools import chain
-
-        batch_size = len(features)
-        num_choices = len(features[0]["input_ids"])
-        flattened_features = [
-            [{k: v[i] for k, v in feature.items()} for i in range(num_choices)]
-            for feature in features
-        ]
-        flattened_features = list(chain(*flattened_features))
-        batch = super(DataCollatorForPredict, self).__call__(flattened_features)
-        # Un-flatten
-        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
-        return batch
+task_to_datacollator_class = OrderedDict(
+    [
+        (TOKENCLASSIFICATION, DataCollatorForTokenClassification),
+        (MULTICHOICECLASSIFICATION, DataCollatorForMultipleChoiceClassification),
+    ]
+)
diff --git a/flaml/nlp/huggingface/training_args.py b/flaml/nlp/huggingface/training_args.py
index 3f5efe61a..e65b3bb0a 100644
--- a/flaml/nlp/huggingface/training_args.py
+++ b/flaml/nlp/huggingface/training_args.py
@@ -17,13 +17,21 @@ class TrainingArgumentsForAuto(TrainingArguments):
     """FLAML custom TrainingArguments.
 
     Args:
+        task (str): the task name for NLP tasks, e.g., seq-classification, token-classification
         output_dir (str): data root directory for outputing the log, etc.
         model_path (str, optional, defaults to "facebook/muppet-roberta-base"): A string,
             the path of the language model file, either a path from huggingface
             model card huggingface.co/models, or a local path for the model.
         fp16 (bool, optional, defaults to "False"): A bool, whether to use FP16.
         max_seq_length (int, optional, defaults to 128): An integer, the max length of the sequence.
+        pad_to_max_length (bool, optional, defaults to "False"):
+            whether to pad all samples to model maximum sentence length.
+            If False, will pad the samples dynamically when batching to the maximum length in the batch.
         ckpt_per_epoch (int, optional, defaults to 1): An integer, the number of checkpoints per epoch.
+        per_device_eval_batch_size (int, optional, defaults to 1): An integer, the per gpu evaluation batch size.
+        label_list (List[str], optional, defaults to None): A list of string, the string list of the label names.
+            When the task is sequence labeling/token classification, need to set the label_list (e.g., B-PER, I-PER, B-LOC)
+            to obtain the correct evaluation metric. See the example in test/nlp/test_autohf_tokenclassification.py.
""" task: str = field(default="seq-classification") @@ -37,21 +45,15 @@ class TrainingArgumentsForAuto(TrainingArguments): }, ) - tokenizer_model_path: str = field( - default=None, - metadata={"help": "tokenizer model path for HPO"}, - ) - fp16: bool = field(default=True, metadata={"help": "whether to use the FP16 mode"}) max_seq_length: int = field(default=128, metadata={"help": "max seq length"}) pad_to_max_length: bool = field( - default=True, + default=False, metadata={ "help": "Whether to pad all samples to model maximum sentence length. " - "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " - "efficient on GPU but very bad for TPU." + "If False, will pad the samples dynamically when batching to the maximum length in the batch. " }, ) @@ -62,21 +64,8 @@ class TrainingArgumentsForAuto(TrainingArguments): metadata={"help": "per gpu evaluation batch size"}, ) - report_to: Optional[List[str]] = field( - default=None, - metadata={ - "help": "The list of integrations to report the results and logs to." - }, - ) - - do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) - do_eval: bool = field( - default=False, metadata={"help": "Whether to run eval on the dev set."} - ) - - metric_for_best_model: Optional[str] = field( - default="loss", - metadata={"help": "The metric to use to compare two different models."}, + label_list: Optional[List[str]] = field( + default=None, metadata={"help": "The string list of the label names. "} ) @staticmethod diff --git a/flaml/nlp/utils.py b/flaml/nlp/utils.py index ab26eb803..cd2e7a409 100644 --- a/flaml/nlp/utils.py +++ b/flaml/nlp/utils.py @@ -83,7 +83,9 @@ def tokenize_and_align_labels( ): tokenized_inputs = tokenizer( [list(examples[X_sent_key])], - padding="max_length", + padding="max_length" + if hf_args.pad_to_max_length + else False, # to be consistent with https://github.com/huggingface/transformers/blob/main/examples/pytorch/token-classification/run_ner.py#L394 truncation=True, max_length=hf_args.max_seq_length, # We use this argument because the texts in our dataset are lists of words (with a label for each word). 
@@ -113,11 +115,11 @@ def tokenize_and_align_labels( # else: # label_ids.append(b_to_i_label[label_to_id[label[word_idx]]]) previous_word_idx = word_idx - tokenized_inputs["label"] = label_ids + tokenized_inputs["labels"] = label_ids tmp_column_names = sorted(tokenized_inputs.keys()) tokenized_input_and_labels = [tokenized_inputs[x] for x in tmp_column_names] for key_idx, each_key in enumerate(tmp_column_names): - if each_key != "label": + if each_key != "labels": tokenized_input_and_labels[key_idx] = tokenized_input_and_labels[key_idx][0] if return_column_name: return tokenized_input_and_labels, tmp_column_names @@ -151,7 +153,7 @@ def tokenize_text_tokclassification(X, Y, tokenizer, hf_args=None): axis=1, result_type="expand", ) - label_idx = tokenized_column_names.index("label") + label_idx = tokenized_column_names.index("labels") other_indices = sorted( set(range(len(tokenized_column_names))).difference({label_idx}) ) diff --git a/test/automl/test_classification.py b/test/automl/test_classification.py index 6ca9db785..cb09b13ff 100644 --- a/test/automl/test_classification.py +++ b/test/automl/test_classification.py @@ -347,3 +347,4 @@ class TestClassification(unittest.TestCase): if __name__ == "__main__": unittest.main() + test = TestClassification() diff --git a/test/nlp/test_autohf_tokenclassification.py b/test/nlp/test_autohf_tokenclassification.py index 9fb3a663a..7636193ce 100644 --- a/test/nlp/test_autohf_tokenclassification.py +++ b/test/nlp/test_autohf_tokenclassification.py @@ -13,7 +13,18 @@ def test_tokenclassification(): automl_settings = get_automl_settings() automl_settings["task"] = "token-classification" - automl_settings["metric"] = "seqeval" + automl_settings["metric"] = "seqeval:overall_f1" # evaluating based on the overall_f1 of seqeval + automl_settings["fit_kwargs_by_estimator"]["transformer"]["label_list"] = [ + "O", + "B-PER", + "I-PER", + "B-ORG", + "I-ORG", + "B-LOC", + "I-LOC", + "B-MISC", + "I-MISC", + ] try: automl.fit(
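For reference, the user-facing flow exercised by the test above looks roughly like the following. The data-loading helper and the exact input format are placeholders (the real test builds its settings via `get_automl_settings()` and loads a CoNLL-2003 sample); the `task`, `metric`, and `label_list` values mirror this patch.

```python
# Hedged end-to-end sketch of NER tuning with FLAML after this patch.
from flaml import AutoML

# X_train: DataFrame with one column holding the token list of each sentence;
# y_train: Series where each element is the list of integer tag ids for that sentence.
X_train, y_train = load_conll_sample()  # hypothetical helper, not part of FLAML

automl = AutoML()
automl.fit(
    X_train=X_train,
    y_train=y_train,
    task="token-classification",
    metric="seqeval:overall_f1",  # tune on seqeval's overall_f1
    estimator_list=["transformer"],
    fit_kwargs_by_estimator={
        "transformer": {
            # Required for NER so integer predictions can be mapped back to string tags.
            "label_list": [
                "O", "B-PER", "I-PER", "B-ORG", "I-ORG",
                "B-LOC", "I-LOC", "B-MISC", "I-MISC",
            ],
        }
    },
    time_budget=100,
)
```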