fix the post-processing bug in NER (#534)

* fix conll bug
* update DataCollatorForAuto
* adding label_list comments

parent c1bb66980c
commit 2a8decdc50
flaml/automl.py

@@ -1569,7 +1569,10 @@ class AutoML(BaseEstimator):
             ```

         **fit_kwargs: Other key word arguments to pass to fit() function of
-            the searched learners, such as sample_weight.
+            the searched learners, such as sample_weight. Include:
+                period: int | forecast horizon for ts_forecast tasks.
+                gpu_per_trial: float, default = 0 | A float of the number of gpus per trial,
+                    only used by TransformersEstimator and XGBoostSklearnEstimator.
         """
         task = task or self._settings.get("task")
         eval_method = eval_method or self._settings.get("eval_method")
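For context, a minimal sketch of how these pass-through fit_kwargs are used; the data is synthetic and the settings illustrative, not taken from the diff:

```python
import numpy as np
from flaml import AutoML

X_train = np.random.rand(100, 4)
y_train = np.random.randint(2, size=100)

automl = AutoML()
# **fit_kwargs (here sample_weight, the docstring's named example) are
# forwarded verbatim to the fit() of every searched learner.
automl.fit(
    X_train,
    y_train,
    task="classification",
    time_budget=10,
    sample_weight=np.ones(100),
)
```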
@@ -2198,6 +2201,8 @@ class AutoML(BaseEstimator):
                 }
             }
         fit_kwargs_by_estimator: dict, default=None | The user specified keywords arguments, grouped by estimator name.
+            For TransformersEstimator, available fit_kwargs can be found from
+            [flaml/nlp/training_args.py:TrainingArgumentsForAuto](https://microsoft.github.io/FLAML/docs/reference/nlp/huggingface/training_args).
             e.g.,

             ```python
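The hunk cuts off at the start of the docstring's inline example. As an illustration of the shape being documented, a fit_kwargs_by_estimator value might look like the following; the estimator name "transformer" and the keys are taken from elsewhere in this commit, while the path value is made up:

```python
# Keyword arguments grouped by estimator name; keys under "transformer"
# correspond to fields of TrainingArgumentsForAuto.
fit_kwargs_by_estimator = {
    "transformer": {
        "output_dir": "data/output/",           # hypothetical output path
        "label_list": ["O", "B-PER", "I-PER"],  # subset, for illustration
    }
}
```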
flaml/ml.py

@@ -165,7 +165,7 @@ def metric_loss_score(
         import datasets

         datasets_metric_name = huggingface_submetric_to_metric.get(
-            metric_name, metric_name
+            metric_name, metric_name.split(":")[0]
         )
         metric = datasets.load_metric(datasets_metric_name)
         metric_mode = huggingface_metric_to_mode[datasets_metric_name]
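With this change a submetric can ride along in the metric name: everything before the colon names the datasets metric to load, and the part after it (used further down) selects the key of the returned score dict. A quick illustration:

```python
metric_name = "seqeval:overall_f1"
# "seqeval" is what gets passed to datasets.load_metric(...)
datasets_metric_name = metric_name.split(":")[0]
assert datasets_metric_name == "seqeval"
```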
@@ -174,17 +174,30 @@ def metric_loss_score(
             score = metric.compute(predictions=y_predict, references=y_true)[
                 metric_name
             ].mid.fmeasure
-        elif metric_name == "seqeval":
-            y_true = [
-                [x for x in each_y_true if x != -100] for each_y_true in y_true
-            ]
-            y_pred = [
-                y_predict[each_idx][: len(y_true[each_idx])]
-                for each_idx in range(len(y_predict))
-            ]
+        elif metric_name.startswith("seqeval"):
+            zip_pred_true = [
+                [(p, lb) for (p, lb) in zip(prediction, label) if lb != -100]
+                for (prediction, label) in zip(y_predict, y_true)
+            ]
+            y_pred = [
+                [labels[p] for (p, l) in each_list]
+                for each_list in zip_pred_true
+            ]  # To compute precision and recall, y_pred and y_true must be converted to string labels
+            # (B-PER, I-PER, etc.), so that the category-based precision/recall (i.e., PER, LOC, etc.) scores can be computed
+            y_true = [
+                [labels[l] for (p, l) in each_list]
+                for each_list in zip_pred_true
+            ]
+
+            metric_submetric_names = metric_name.split(":")
+
             score = metric.compute(predictions=y_pred, references=y_true)[
-                "overall_accuracy"
+                metric_submetric_names[1]
+                if len(metric_submetric_names) > 1
+                else "overall_accuracy"
             ]
         else:
             score = metric.compute(predictions=y_predict, references=y_true)[
                 metric_name
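The old code truncated predictions by the padded length of y_true and scored integer ids, which breaks seqeval's entity-level precision/recall. The new code drops exactly the positions whose gold label is -100 (special/sub-word tokens) and maps both sides to string labels. A self-contained sketch with toy data; the label list here is an assumption:

```python
labels = ["O", "B-PER", "I-PER"]   # label id -> name, e.g. fit_kwargs["label_list"]
y_predict = [[0, 1, 2, 0]]         # predicted label ids, one list per sentence
y_true = [[0, 1, 2, -100]]         # gold ids; -100 marks special tokens

# Keep only positions with a real gold label, preserving alignment.
zip_pred_true = [
    [(p, lb) for (p, lb) in zip(prediction, label) if lb != -100]
    for (prediction, label) in zip(y_predict, y_true)
]
# Convert ids to strings so seqeval can group B-/I- tags into entities.
y_pred = [[labels[p] for (p, l) in each_list] for each_list in zip_pred_true]
y_true = [[labels[l] for (p, l) in each_list] for each_list in zip_pred_true]
print(y_pred)  # [['O', 'B-PER', 'I-PER']]
print(y_true)  # [['O', 'B-PER', 'I-PER']]
```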
@@ -454,7 +467,9 @@ def evaluate_model_CV(
     if task in CLASSIFICATION:
         labels = np.unique(y_train_all)
     else:
-        labels = None
+        labels = fit_kwargs.get(
+            "label_list"
+        )  # pass the label list on to compute the evaluation metric
     groups = None
     shuffle = False if task in TS_FORECAST else True
     if isinstance(kf, RepeatedStratifiedKFold):
@@ -586,6 +601,9 @@ def compute_estimator(
         groups_val,
         eval_metric,
         task,
+        labels=fit_kwargs.get(
+            "label_list"
+        ),  # pass the label list on to compute the evaluation metric
         budget=budget,
         log_training_metric=log_training_metric,
         fit_kwargs=fit_kwargs,
flaml/model.py

@@ -365,9 +365,9 @@ class TransformersEstimator(BaseEstimator):
         self._TrainingArguments = TrainingArguments

     @staticmethod
-    def _join(X_train, y_train):
+    def _join(X_train, y_train, task):
         y_train = DataFrame(y_train, index=X_train.index)
-        y_train.columns = ["label"]
+        y_train.columns = ["label"] if task != TOKENCLASSIFICATION else ["labels"]
         train_df = X_train.join(y_train)
         return train_df

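The rename matters because HuggingFace's token-classification collator and Trainer look for a "labels" field (a list of per-token ids), whereas a single "label" column is the convention for sequence-level tasks. A toy reproduction of what _join now produces, on synthetic data:

```python
import pandas as pd

X_train = pd.DataFrame({"tokens": [["EU", "rejects"], ["Peter", "Blackburn"]]})
y_train = pd.Series([[3, 0], [1, 2]])  # per-token label ids

y_df = pd.DataFrame(y_train, index=X_train.index)
y_df.columns = ["labels"]  # token classification: "labels", not "label"
train_df = X_train.join(y_df)
print(train_df.columns.tolist())  # ['tokens', 'labels']
```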
@@ -380,7 +380,7 @@ class TransformersEstimator(BaseEstimator):
             },
             "num_train_epochs": {
                 "domain": tune.loguniform(lower=0.1, upper=10.0),
-                "init_value": 1,
+                "init_value": 3.0,  # to be consistent with roberta
             },
             "per_device_train_batch_size": {
                 "domain": tune.choice([4, 8, 16, 32]),
@@ -511,7 +511,7 @@ class TransformersEstimator(BaseEstimator):
         processed_X, processed_y = self._preprocess(X=X, y=y, **self._kwargs)

         processed_dataset = Dataset.from_pandas(
-            TransformersEstimator._join(processed_X, processed_y)
+            TransformersEstimator._join(processed_X, processed_y, self._task)
         )
         return processed_dataset, processed_X, processed_y

@@ -547,14 +547,14 @@ class TransformersEstimator(BaseEstimator):

     @property
     def data_collator(self):
-        from .nlp.huggingface.data_collator import DataCollatorForAuto
+        from .nlp.huggingface.data_collator import task_to_datacollator_class

         return (
-            DataCollatorForAuto(
+            task_to_datacollator_class[self._task](
                 tokenizer=self.tokenizer,
-                pad_to_multiple_of=8 if self._training_args.fp16 else None,
+                pad_to_multiple_of=8,  # if self._training_args.fp16 else None,
             )
-            if self._task == MULTICHOICECLASSIFICATION
+            if self._task in (MULTICHOICECLASSIFICATION, TOKENCLASSIFICATION)
             else None
         )

@@ -750,7 +750,10 @@ class TransformersEstimator(BaseEstimator):
             )
             metric_dict = {
                 "automl_metric": metric_loss_score(
-                    metric_name=self._metric, y_predict=predictions, y_true=labels
+                    metric_name=self._metric,
+                    y_predict=predictions,
+                    y_true=labels,
+                    labels=self._training_args.label_list,
                 )
             }
         else:
flaml/nlp/huggingface/data_collator.py

@@ -1,9 +1,15 @@
 from dataclasses import dataclass
-from transformers.data.data_collator import DataCollatorWithPadding
+from transformers.data.data_collator import (
+    DataCollatorWithPadding,
+    DataCollatorForTokenClassification,
+)
+from collections import OrderedDict
+
+from flaml.data import TOKENCLASSIFICATION, MULTICHOICECLASSIFICATION


 @dataclass
-class DataCollatorForAuto(DataCollatorWithPadding):
+class DataCollatorForMultipleChoiceClassification(DataCollatorWithPadding):
     def __call__(self, features):
         from itertools import chain
         import torch
@@ -22,7 +28,9 @@ class DataCollatorForMultipleChoiceClassification(DataCollatorWithPadding):
             for feature in features
         ]
         flattened_features = list(chain(*flattened_features))
-        batch = super(DataCollatorForAuto, self).__call__(flattened_features)
+        batch = super(DataCollatorForMultipleChoiceClassification, self).__call__(
+            flattened_features
+        )
         # Un-flatten
         batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
         # Add back labels
@@ -31,18 +39,9 @@ class DataCollatorForMultipleChoiceClassification(DataCollatorWithPadding):
         return batch


-class DataCollatorForPredict(DataCollatorWithPadding):
-    def __call__(self, features):
-        from itertools import chain
-
-        batch_size = len(features)
-        num_choices = len(features[0]["input_ids"])
-        flattened_features = [
-            [{k: v[i] for k, v in feature.items()} for i in range(num_choices)]
-            for feature in features
-        ]
-        flattened_features = list(chain(*flattened_features))
-        batch = super(DataCollatorForPredict, self).__call__(flattened_features)
-        # Un-flatten
-        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
-        return batch
+task_to_datacollator_class = OrderedDict(
+    [
+        (TOKENCLASSIFICATION, DataCollatorForTokenClassification),
+        (MULTICHOICECLASSIFICATION, DataCollatorForMultipleChoiceClassification),
+    ]
+)
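For token classification the stock DataCollatorForTokenClassification is reused: it pads the per-token "labels" alongside input_ids, filling padded positions with -100, which is exactly the sentinel the metric post-processing above filters out. A runnable sketch, assuming transformers and torch are installed; the checkpoint is arbitrary:

```python
from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForTokenClassification(
    tokenizer=tokenizer, pad_to_multiple_of=8
)

features = [
    {"input_ids": [101, 7327, 102], "labels": [-100, 1, -100]},
    {"input_ids": [101, 7327, 2765, 2015, 102], "labels": [-100, 1, 2, 2, -100]},
]
batch = collator(features)
print(batch["labels"].shape)  # torch.Size([2, 8]); short rows padded with -100
```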
flaml/nlp/huggingface/training_args.py

@@ -17,13 +17,21 @@ class TrainingArgumentsForAuto(TrainingArguments):
     """FLAML custom TrainingArguments.

     Args:
+        task (str): the task name for NLP tasks, e.g., seq-classification, token-classification
         output_dir (str): data root directory for outputing the log, etc.
         model_path (str, optional, defaults to "facebook/muppet-roberta-base"): A string,
             the path of the language model file, either a path from huggingface
             model card huggingface.co/models, or a local path for the model.
         fp16 (bool, optional, defaults to "False"): A bool, whether to use FP16.
         max_seq_length (int, optional, defaults to 128): An integer, the max length of the sequence.
+        pad_to_max_length (bool, optional, defaults to "False"):
+            whether to pad all samples to model maximum sentence length.
+            If False, will pad the samples dynamically when batching to the maximum length in the batch.
         ckpt_per_epoch (int, optional, defaults to 1): An integer, the number of checkpoints per epoch.
+        per_device_eval_batch_size (int, optional, defaults to 1): An integer, the per gpu evaluation batch size.
+        label_list (List[str], optional, defaults to None): A list of string, the string list of the label names.
+            When the task is sequence labeling/token classification, need to set the label_list (e.g., B-PER, I-PER, B-LOC)
+            to obtain the correct evaluation metric. See the example in test/nlp/test_autohf_tokenclassification.py.
     """

     task: str = field(default="seq-classification")
@@ -37,21 +45,15 @@ class TrainingArgumentsForAuto(TrainingArguments):
         },
     )

-    tokenizer_model_path: str = field(
-        default=None,
-        metadata={"help": "tokenizer model path for HPO"},
-    )
-
     fp16: bool = field(default=True, metadata={"help": "whether to use the FP16 mode"})

     max_seq_length: int = field(default=128, metadata={"help": "max seq length"})

     pad_to_max_length: bool = field(
-        default=True,
+        default=False,
         metadata={
             "help": "Whether to pad all samples to model maximum sentence length. "
-            "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
-            "efficient on GPU but very bad for TPU."
+            "If False, will pad the samples dynamically when batching to the maximum length in the batch. "
         },
     )
@@ -62,21 +64,8 @@ class TrainingArgumentsForAuto(TrainingArguments):
         metadata={"help": "per gpu evaluation batch size"},
     )

-    report_to: Optional[List[str]] = field(
-        default=None,
-        metadata={
-            "help": "The list of integrations to report the results and logs to."
-        },
-    )
-
-    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
-    do_eval: bool = field(
-        default=False, metadata={"help": "Whether to run eval on the dev set."}
-    )
-
-    metric_for_best_model: Optional[str] = field(
-        default="loss",
-        metadata={"help": "The metric to use to compare two different models."},
+    label_list: Optional[List[str]] = field(
+        default=None, metadata={"help": "The string list of the label names. "}
     )

     @staticmethod
flaml/nlp/utils.py

@@ -83,7 +83,9 @@ def tokenize_and_align_labels(
 ):
     tokenized_inputs = tokenizer(
         [list(examples[X_sent_key])],
-        padding="max_length",
+        padding="max_length"
+        if hf_args.pad_to_max_length
+        else False,  # to be consistent with https://github.com/huggingface/transformers/blob/main/examples/pytorch/token-classification/run_ner.py#L394
         truncation=True,
         max_length=hf_args.max_seq_length,
         # We use this argument because the texts in our dataset are lists of words (with a label for each word).
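This makes padding conditional, mirroring HuggingFace's run_ner.py: pad every sentence to max_seq_length only when pad_to_max_length is set, and otherwise leave padding to the data collator at batch time. An illustrative call, with the same assumed tokenizer as in the earlier sketch:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
words = [["EU", "rejects", "German", "call"]]
pad_to_max_length = False  # the new default in TrainingArgumentsForAuto

tokenized = tokenizer(
    words,
    padding="max_length" if pad_to_max_length else False,
    truncation=True,
    max_length=128,
    is_split_into_words=True,  # inputs are pre-split word lists
)
print(len(tokenized["input_ids"][0]))  # well under 128: nothing padded here
```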
@@ -113,11 +115,11 @@ def tokenize_and_align_labels(
         # else:
         #     label_ids.append(b_to_i_label[label_to_id[label[word_idx]]])
         previous_word_idx = word_idx
-    tokenized_inputs["label"] = label_ids
+    tokenized_inputs["labels"] = label_ids
     tmp_column_names = sorted(tokenized_inputs.keys())
     tokenized_input_and_labels = [tokenized_inputs[x] for x in tmp_column_names]
     for key_idx, each_key in enumerate(tmp_column_names):
-        if each_key != "label":
+        if each_key != "labels":
             tokenized_input_and_labels[key_idx] = tokenized_input_and_labels[key_idx][0]
     if return_column_name:
         return tokenized_input_and_labels, tmp_column_names
@@ -151,7 +153,7 @@ def tokenize_text_tokclassification(X, Y, tokenizer, hf_args=None):
         axis=1,
         result_type="expand",
     )
-    label_idx = tokenized_column_names.index("label")
+    label_idx = tokenized_column_names.index("labels")
     other_indices = sorted(
         set(range(len(tokenized_column_names))).difference({label_idx})
     )
@@ -347,3 +347,4 @@ class TestClassification(unittest.TestCase):

 if __name__ == "__main__":
     unittest.main()
+    test = TestClassification()
test/nlp/test_autohf_tokenclassification.py

@@ -13,7 +13,18 @@ def test_tokenclassification():

     automl_settings = get_automl_settings()
     automl_settings["task"] = "token-classification"
-    automl_settings["metric"] = "seqeval"
+    automl_settings["metric"] = "seqeval:overall_f1"  # evaluating based on the overall_f1 of seqeval
+    automl_settings["fit_kwargs_by_estimator"]["transformer"]["label_list"] = [
+        "O",
+        "B-PER",
+        "I-PER",
+        "B-ORG",
+        "I-ORG",
+        "B-LOC",
+        "I-LOC",
+        "B-MISC",
+        "I-MISC",
+    ]

     try:
         automl.fit(