mirror of https://github.com/microsoft/autogen.git (synced 2025-11-11 23:54:52 +00:00)
commit 207b6935d9, parent 8602def1c4
@@ -40,6 +40,7 @@ from .config import (
 from .data import (
     concat,
     CLASSIFICATION,
+    TOKENCLASSIFICATION,
     TS_FORECAST,
     FORECAST,
     REGRESSION,
@@ -866,6 +867,8 @@ class AutoML(BaseEstimator):
 
         # check the validity of input dimensions under the nlp mode
         if _is_nlp_task(self._state.task):
+            from .nlp.utils import is_a_list_of_str
+
             is_all_str = True
             is_all_list = True
             for column in X.columns:
@@ -874,17 +877,25 @@ class AutoML(BaseEstimator):
                     "string",
                 ), "If the task is an NLP task, X can only contain text columns"
                 for each_cell in X[column]:
-                    if each_cell:
+                    if each_cell is not None:
                         is_str = isinstance(each_cell, str)
                         is_list_of_int = isinstance(each_cell, list) and all(
                             isinstance(x, int) for x in each_cell
                         )
+                        is_list_of_str = is_a_list_of_str(each_cell)
+                        if self._state.task == TOKENCLASSIFICATION:
+                            assert is_list_of_str, (
+                                "For the token-classification task, the input column needs to be a list of string,"
+                                "instead of string, e.g., ['EU', 'rejects','German', 'call','to','boycott','British','lamb','.',].",
+                                "For more examples, please refer to test/nlp/test_autohf_tokenclassification.py",
+                            )
+                        else:
                             assert is_str or is_list_of_int, (
                                 "Each column of the input must either be str (untokenized) "
                                 "or a list of integers (tokenized)"
                             )
                         is_all_str &= is_str
-                        is_all_list &= is_list_of_int
+                        is_all_list &= is_list_of_int or is_list_of_str
             assert is_all_str or is_all_list, (
                 "Currently FLAML only supports two modes for NLP: either all columns of X are string (non-tokenized), "
                 "or all columns of X are integer ids (tokenized)"
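For orientation, here is a minimal sketch (toy data, helper names mirror the predicate this commit adds in flaml/nlp/utils.py) of the two input shapes the check above distinguishes: token classification requires each cell of X to be a list of strings, while the untokenized mode expects plain strings.

    import pandas as pd

    # token-classification mode: every cell holds a list of words
    X_tok = pd.DataFrame(
        {"tokens": [["EU", "rejects", "German", "call"], ["Peter", "Blackburn"]]}
    )

    # untokenized mode: every cell holds a plain string
    X_str = pd.DataFrame({"sentence": ["EU rejects German call", "Peter Blackburn"]})


    def is_a_list_of_str(obj):
        # same predicate the commit adds in flaml/nlp/utils.py
        return isinstance(obj, list) and all(isinstance(x, str) for x in obj)


    assert all(is_a_list_of_str(cell) for cell in X_tok["tokens"])
    assert all(isinstance(cell, str) for cell in X_str["sentence"])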
@@ -963,6 +974,7 @@ class AutoML(BaseEstimator):
             and self._auto_augment
             and self._state.fit_kwargs.get("sample_weight") is None
             and self._split_type in ["stratified", "uniform"]
+            and self._state.task != TOKENCLASSIFICATION
         ):
             # logger.info(f"label {pd.unique(y_train_all)}")
             label_set, counts = np.unique(y_train_all, return_counts=True)
@@ -15,12 +15,14 @@ from typing import Dict, Union, List
 # TODO: if your task is not specified in here, define your task as an all-capitalized word
 SEQCLASSIFICATION = "seq-classification"
 MULTICHOICECLASSIFICATION = "multichoice-classification"
+TOKENCLASSIFICATION = "token-classification"
 CLASSIFICATION = (
     "binary",
     "multi",
     "classification",
     SEQCLASSIFICATION,
     MULTICHOICECLASSIFICATION,
+    TOKENCLASSIFICATION,
 )
 SEQREGRESSION = "seq-regression"
 REGRESSION = ("regression", SEQREGRESSION)
@@ -34,6 +36,7 @@ NLU_TASKS = (
     SEQREGRESSION,
     SEQCLASSIFICATION,
     MULTICHOICECLASSIFICATION,
+    TOKENCLASSIFICATION,
 )
 
 
@@ -354,11 +357,10 @@ class DataTransformer:
                 datetime_columns,
             )
         self._drop = drop
 
         if (
-            task in CLASSIFICATION
-            or not pd.api.types.is_numeric_dtype(y)
+            (task in CLASSIFICATION or not pd.api.types.is_numeric_dtype(y))
             and task not in NLG_TASKS
+            and task != TOKENCLASSIFICATION
         ):
             from sklearn.preprocessing import LabelEncoder
 
flaml/ml.py (13 changes)
@@ -164,11 +164,21 @@ def metric_loss_score(
                 score = metric.compute(predictions=y_predict, references=y_true)[
                     metric_name
                 ].mid.fmeasure
+            elif metric_name == "seqeval":
+                y_true = [
+                    [x for x in each_y_true if x != -100] for each_y_true in y_true
+                ]
+                y_pred = [
+                    y_predict[each_idx][: len(y_true[each_idx])]
+                    for each_idx in range(len(y_predict))
+                ]
+                score = metric.compute(predictions=y_pred, references=y_true)[
+                    "overall_accuracy"
+                ]
             else:
                 score = metric.compute(predictions=y_predict, references=y_true)[
                     metric_name
                 ]
 
     except ImportError:
         raise Exception(
             metric_name
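A toy sketch (values invented) of the preprocessing this seqeval branch performs: label ids of -100 mark positions the loss ignored (special tokens, padding), so they are dropped from each reference row, and the matching prediction row is cut to the same length before seqeval computes overall_accuracy.

    y_true = [[3, 0, 7, -100, -100], [1, 2, -100]]
    y_predict = [[3, 0, 7, 0, 0], [1, 0, 5]]

    # drop the ignored positions from the references
    y_true_clean = [[x for x in row if x != -100] for row in y_true]
    # -> [[3, 0, 7], [1, 2]]

    # keep only the leading predictions of the same length
    y_pred_clean = [y_predict[i][: len(y_true_clean[i])] for i in range(len(y_predict))]
    # -> [[3, 0, 7], [1, 0]]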
@@ -226,6 +236,7 @@ def sklearn_metric_loss_score(
     Returns:
         score: A float number of the loss, the lower the better.
     """
+
     metric_name = metric_name.lower()
 
     if "r2" == metric_name:
@@ -25,6 +25,7 @@ from .data import (
     TS_VALUE_COL,
     SEQCLASSIFICATION,
     SEQREGRESSION,
+    TOKENCLASSIFICATION,
     SUMMARIZATION,
     NLG_TASKS,
     MULTICHOICECLASSIFICATION,
@@ -310,7 +311,8 @@ class TransformersEstimator(BaseEstimator):
 
     @staticmethod
     def _join(X_train, y_train):
-        y_train = DataFrame(y_train, columns=["label"], index=X_train.index)
+        y_train = DataFrame(y_train, index=X_train.index)
+        y_train.columns = ["label"]
         train_df = X_train.join(y_train)
         return train_df
 
@@ -370,17 +372,12 @@ class TransformersEstimator(BaseEstimator):
         self.custom_hpo_args = custom_hpo_args
 
     def _preprocess(self, X, y=None, **kwargs):
-        from .nlp.utils import tokenize_text
+        from .nlp.utils import tokenize_text, is_a_list_of_str
 
-        # is_str = False
-        # for each_type in ["string", "str"]:
-        #     try:
-        #         is_str = is_str or (X.dtypes[0] == each_type)
-        #     except TypeError:
-        #         pass
         is_str = str(X.dtypes[0]) in ("string", "str")
+        is_list_of_str = is_a_list_of_str(X[list(X.keys())[0]].to_list()[0])
 
-        if is_str:
+        if is_str or is_list_of_str:
             return tokenize_text(
                 X=X, Y=y, task=self._task, custom_hpo_args=self.custom_hpo_args
             )
@@ -391,6 +388,7 @@ class TransformersEstimator(BaseEstimator):
         from transformers import EarlyStoppingCallback
         from transformers.trainer_utils import set_seed
         from transformers import AutoTokenizer
+        from transformers.data import DataCollatorWithPadding
 
         import transformers
         from datasets import Dataset
@@ -455,7 +453,7 @@ class TransformersEstimator(BaseEstimator):
         X_val = kwargs.get("X_val")
         y_val = kwargs.get("y_val")
 
-        if self._task not in NLG_TASKS:
+        if (self._task not in NLG_TASKS) and (self._task != TOKENCLASSIFICATION):
             self._X_train, _ = self._preprocess(X=X_train, **kwargs)
             self._y_train = y_train
         else:
@@ -474,7 +472,7 @@ class TransformersEstimator(BaseEstimator):
         # make sure they are the same
 
         if X_val is not None:
-            if self._task not in NLG_TASKS:
+            if (self._task not in NLG_TASKS) and (self._task != TOKENCLASSIFICATION):
                 self._X_val, _ = self._preprocess(X=X_val, **kwargs)
                 self._y_val = y_val
             else:
@@ -648,6 +646,8 @@ class TransformersEstimator(BaseEstimator):
             predictions = (
                 np.squeeze(predictions)
                 if self._task == SEQREGRESSION
+                else np.argmax(predictions, axis=2)
+                if self._task == TOKENCLASSIFICATION
                 else np.argmax(predictions, axis=1)
             )
             return {
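For context, a shape-only sketch (toy sizes, not FLAML API) of why token classification needs axis=2: its logits carry one row of class scores per token, so the argmax has to run over the last axis rather than the per-example axis.

    import numpy as np

    tok_logits = np.zeros((8, 128, 9))          # (examples, tokens, labels)
    tok_preds = np.argmax(tok_logits, axis=2)   # (8, 128): one label id per token

    seq_logits = np.zeros((8, 9))               # (examples, labels)
    seq_preds = np.argmax(seq_logits, axis=1)   # (8,): one label id per example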
@@ -724,7 +724,9 @@ class TransformersEstimator(BaseEstimator):
         if self._task == SEQCLASSIFICATION:
             return np.argmax(predictions.predictions, axis=1)
         elif self._task == SEQREGRESSION:
-            return predictions.predictions
+            return predictions.predictions.reshape((len(predictions.predictions),))
+        elif self._task == TOKENCLASSIFICATION:
+            return np.argmax(predictions.predictions, axis=2)
         # TODO: elif self._task == your task, return the corresponding prediction
         # e.g., if your task == QUESTIONANSWERING, you need to return the answer instead
         # of the index
@@ -5,9 +5,14 @@ import transformers
 if transformers.__version__.startswith("3"):
     from transformers.modeling_electra import ElectraClassificationHead
     from transformers.modeling_roberta import RobertaClassificationHead
+    from transformers.models.electra.modeling_electra import ElectraForTokenClassification
+    from transformers.models.roberta.modeling_roberta import RobertaForTokenClassification
+
 else:
     from transformers.models.electra.modeling_electra import ElectraClassificationHead
     from transformers.models.roberta.modeling_roberta import RobertaClassificationHead
+    from transformers.models.electra.modeling_electra import ElectraForTokenClassification
+    from transformers.models.roberta.modeling_roberta import RobertaForTokenClassification
 
 MODEL_CLASSIFICATION_HEAD_MAPPING = OrderedDict(
     [
@@ -7,12 +7,14 @@ from ..data import (
     SUMMARIZATION,
     SEQREGRESSION,
     SEQCLASSIFICATION,
-    NLG_TASKS,
     MULTICHOICECLASSIFICATION,
+    TOKENCLASSIFICATION,
+    NLG_TASKS,
 )
 
 
 def load_default_huggingface_metric_for_task(task):
 
     if task == SEQCLASSIFICATION:
         return "accuracy", "max"
     elif task == SEQREGRESSION:
@@ -20,15 +22,9 @@ def load_default_huggingface_metric_for_task(task):
     elif task == SUMMARIZATION:
         return "rouge", "max"
     elif task == MULTICHOICECLASSIFICATION:
-        return "accuracy"
-    # TODO: elif task == your task, return the default metric name for your task,
-    # e.g., if task == MULTIPLECHOICE, return "accuracy"
-    # notice this metric name has to be in ['accuracy', 'bertscore', 'bleu', 'bleurt',
-    # 'cer', 'chrf', 'code_eval', 'comet', 'competition_math', 'coval', 'cuad',
-    # 'f1', 'gleu', 'glue', 'google_bleu', 'indic_glue', 'matthews_correlation',
-    # 'meteor', 'pearsonr', 'precision', 'recall', 'rouge', 'sacrebleu', 'sari',
-    # 'seqeval', 'spearmanr', 'squad', 'squad_v2', 'super_glue', 'ter', 'wer',
-    # 'wiki_split', 'xnli']
+        return "accuracy", "max"
+    elif task == TOKENCLASSIFICATION:
+        return "seqeval", "max"
 
 
 global tokenized_column_names
@@ -40,6 +36,8 @@ def tokenize_text(X, Y=None, task=None, custom_hpo_args=None):
             X, this_tokenizer=None, task=task, custom_hpo_args=custom_hpo_args
         )
         return X_tokenized, None
+    elif task == TOKENCLASSIFICATION:
+        return tokenize_text_tokclassification(X, Y, custom_hpo_args)
     elif task in NLG_TASKS:
         return tokenize_seq2seq(X, Y, task=task, custom_hpo_args=custom_hpo_args)
     elif task == MULTICHOICECLASSIFICATION:
@@ -71,6 +69,102 @@ def tokenize_seq2seq(X, Y, task=None, custom_hpo_args=None):
     return model_inputs, labels
 
 
+def tokenize_and_align_labels(
+    examples, tokenizer, custom_hpo_args, X_sent_key, Y_sent_key=None
+):
+    global tokenized_column_names
+
+    tokenized_inputs = tokenizer(
+        [list(examples[X_sent_key])],
+        padding="max_length",
+        truncation=True,
+        max_length=custom_hpo_args.max_seq_length,
+        # We use this argument because the texts in our dataset are lists of words (with a label for each word).
+        is_split_into_words=True,
+    )
+    if Y_sent_key is not None:
+        previous_word_idx = None
+        label_ids = []
+        import numbers
+
+        for word_idx in tokenized_inputs.word_ids(batch_index=0):
+            # Special tokens have a word id that is None. We set the label to -100 so they are automatically
+            # ignored in the loss function.
+            if word_idx is None:
+                label_ids.append(-100)
+            # We set the label for the first token of each word.
+            elif word_idx != previous_word_idx:
+                if isinstance(examples[Y_sent_key][word_idx], numbers.Number):
+                    label_ids.append(examples[Y_sent_key][word_idx])
+                # else:
+                #     label_ids.append(label_to_id[label[word_idx]])
+            # For the other tokens in a word, we set the label to either the current label or -100, depending on
+            # the label_all_tokens flag.
+            else:
+                if isinstance(examples[Y_sent_key][word_idx], numbers.Number):
+                    label_ids.append(examples[Y_sent_key][word_idx])
+                # else:
+                #     label_ids.append(b_to_i_label[label_to_id[label[word_idx]]])
+            previous_word_idx = word_idx
+        tokenized_inputs["label"] = label_ids
+    tokenized_column_names = sorted(tokenized_inputs.keys())
+    tokenized_input_and_labels = [tokenized_inputs[x] for x in tokenized_column_names]
+    for key_idx, each_key in enumerate(tokenized_column_names):
+        if each_key != "label":
+            tokenized_input_and_labels[key_idx] = tokenized_input_and_labels[key_idx][0]
+    return tokenized_input_and_labels
+
+
+def tokenize_text_tokclassification(X, Y, custom_hpo_args):
+    from transformers import AutoTokenizer
+    import pandas as pd
+
+    global tokenized_column_names
+    this_tokenizer = AutoTokenizer.from_pretrained(
+        custom_hpo_args.model_path, use_fast=True
+    )
+    if Y is not None:
+        X_and_Y = pd.concat([X, Y.to_frame()], axis=1)
+        X_key = list(X.keys())[0]
+        Y_key = list(Y.to_frame().keys())[0]
+        X_and_Y_tokenized = X_and_Y.apply(
+            lambda x: tokenize_and_align_labels(
+                x,
+                tokenizer=this_tokenizer,
+                custom_hpo_args=custom_hpo_args,
+                X_sent_key=X_key,
+                Y_sent_key=Y_key,
+            ),
+            axis=1,
+            result_type="expand",
+        )
+        label_idx = tokenized_column_names.index("label")
+        other_indices = sorted(
+            set(range(len(tokenized_column_names))).difference({label_idx})
+        )
+        other_column_names = [tokenized_column_names[x] for x in other_indices]
+        d = X_and_Y_tokenized.iloc[:, other_indices]
+        y_tokenized = X_and_Y_tokenized.iloc[:, label_idx]
+    else:
+        X_key = list(X.keys())[0]
+        d = X.apply(
+            lambda x: tokenize_and_align_labels(
+                x,
+                tokenizer=this_tokenizer,
+                custom_hpo_args=custom_hpo_args,
+                X_sent_key=X_key,
+                Y_sent_key=None,
+            ),
+            axis=1,
+            result_type="expand",
+        )
+        other_column_names = tokenized_column_names
+        y_tokenized = None
+    X_tokenized = pd.DataFrame(columns=other_column_names)
+    X_tokenized[other_column_names] = d
+    return X_tokenized, y_tokenized
+
+
 def tokenize_onedataframe(
     X,
     this_tokenizer=None,
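The alignment above hinges on the fast tokenizer's word_ids(); a minimal sketch of what it returns (the checkpoint name is only an example, and the exact output depends on the vocabulary):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
    enc = tok([["EU", "rejects", "German", "call"]], is_split_into_words=True)
    print(enc.word_ids(batch_index=0))
    # e.g. [None, 0, 1, 2, 3, None]: None for [CLS]/[SEP] (labeled -100 above),
    # and a repeated word index for words split into several sub-tokens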
@@ -229,16 +323,22 @@ def separate_config(config, task):
 
 
 def get_num_labels(task, y_train):
-    from ..data import SEQCLASSIFICATION, SEQREGRESSION
+    from ..data import SEQCLASSIFICATION, SEQREGRESSION, TOKENCLASSIFICATION
 
     if task == SEQREGRESSION:
         return 1
     elif task == SEQCLASSIFICATION:
         return len(set(y_train))
+    elif task == TOKENCLASSIFICATION:
+        return len(set([a for b in y_train.tolist() for a in b]))
     else:
         return None
 
 
+def is_a_list_of_str(this_obj):
+    return isinstance(this_obj, list) and all(isinstance(x, str) for x in this_obj)
+
+
 def _clean_value(value: Any) -> str:
     if isinstance(value, float):
         return "{:.5}".format(value)
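Since token-classification labels arrive as one list of tag ids per sentence, the new get_num_labels branch above flattens them before counting distinct classes; a toy run:

    import pandas as pd

    y_train = pd.Series([[3, 0, 7, 0], [1, 2], [5, 0]])
    num_labels = len(set([a for b in y_train.tolist() for a in b]))
    assert num_labels == 6  # distinct ids: {0, 1, 2, 3, 5, 7}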
@@ -305,7 +405,7 @@ def load_model(checkpoint_path, task, num_labels, per_model_config=None):
         AutoSeqClassificationHead,
         MODEL_CLASSIFICATION_HEAD_MAPPING,
     )
-    from ..data import SEQCLASSIFICATION, SEQREGRESSION
+    from ..data import SEQCLASSIFICATION, SEQREGRESSION, TOKENCLASSIFICATION
 
     this_model_type = AutoConfig.from_pretrained(checkpoint_path).model_type
     this_vocab_size = AutoConfig.from_pretrained(checkpoint_path).vocab_size
@@ -314,15 +414,16 @@ def load_model(checkpoint_path, task, num_labels, per_model_config=None):
         from transformers import AutoModelForSequenceClassification
         from transformers import AutoModelForSeq2SeqLM
         from transformers import AutoModelForMultipleChoice
+        from transformers import AutoModelForTokenClassification
 
         if task in (SEQCLASSIFICATION, SEQREGRESSION):
             return AutoModelForSequenceClassification.from_pretrained(
                 checkpoint_path, config=model_config
             )
-        # TODO: elif task == your task, fill in the line in your transformers example
-        # that loads the model, e.g., if task == MULTIPLE CHOICE, according to
-        # https://github.com/huggingface/transformers/blob/master/examples/pytorch/multiple-choice/run_swag.py#L298
-        # you can return AutoModelForMultipleChoice.from_pretrained(checkpoint_path, config=model_config)
+        elif task == TOKENCLASSIFICATION:
+            return AutoModelForTokenClassification.from_pretrained(
+                checkpoint_path, config=model_config
+            )
         elif task in NLG_TASKS:
             return AutoModelForSeq2SeqLM.from_pretrained(
                 checkpoint_path, config=model_config
@@ -336,7 +437,7 @@ def load_model(checkpoint_path, task, num_labels, per_model_config=None):
         return model_type in MODEL_CLASSIFICATION_HEAD_MAPPING
 
     def _set_model_config(checkpoint_path):
-        if task in (SEQCLASSIFICATION, SEQREGRESSION):
+        if task in (SEQCLASSIFICATION, SEQREGRESSION, TOKENCLASSIFICATION):
             if per_model_config:
                 model_config = AutoConfig.from_pretrained(
                     checkpoint_path,
@@ -385,6 +486,8 @@ def load_model(checkpoint_path, task, num_labels, per_model_config=None):
     else:
         if task == SEQREGRESSION:
             model_config_num_labels = 1
+        elif task == TOKENCLASSIFICATION:
+            model_config_num_labels = num_labels
         model_config = _set_model_config(checkpoint_path)
         this_model = get_this_model(task)
         return this_model
@@ -411,7 +514,6 @@ def compute_checkpoint_freq(
 @dataclass
 class HPOArgs:
     """The HPO setting.
-
     Args:
         output_dir (str): data root directory for outputing the log, etc.
         model_path (str, optional, defaults to "facebook/muppet-roberta-base"): A string,
@@ -420,7 +522,6 @@ class HPOArgs:
         fp16 (bool, optional, defaults to "False"): A bool, whether to use FP16.
         max_seq_length (int, optional, defaults to 128): An integer, the max length of the sequence.
         ckpt_per_epoch (int, optional, defaults to 1): An integer, the number of checkpoints per epoch.
-
     """
 
     output_dir: str = field(
@@ -436,6 +537,15 @@ class HPOArgs:
 
     max_seq_length: int = field(default=128, metadata={"help": "max seq length"})
 
+    pad_to_max_length: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to pad all samples to model maximum sentence length. "
+            "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+            "efficient on GPU but very bad for TPU."
+        },
+    )
+
     ckpt_per_epoch: int = field(default=1, metadata={"help": "checkpoint per epoch"})
 
     @staticmethod
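These dataclass fields are populated from the custom_hpo_args dict passed to AutoML.fit; a sketch mirroring the settings used in the new test below (values illustrative):

    automl_settings = {
        "task": "token-classification",
        "metric": "seqeval",
        "custom_hpo_args": {
            "model_path": "bert-base-uncased",
            "output_dir": "test/data/output/",
            "max_seq_length": 128,
            "ckpt_per_epoch": 5,
        },
    }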
setup.py (3 changes)
@@ -60,6 +60,7 @@ setuptools.setup(
             "torch",
             "nltk",
             "rouge_score",
+            "seqeval",
         ],
         "catboost": ["catboost>=0.26"],
         "blendsearch": ["optuna==2.8.0"],
@@ -76,7 +77,7 @@ setuptools.setup(
         "vw": [
             "vowpalwabbit",
         ],
-        "nlp": ["transformers", "datasets", "torch", "nltk", "rouge_score"],
+        "nlp": ["transformers", "datasets", "torch", "seqeval", "nltk", "rouge_score"],
         "ts_forecast": ["prophet>=1.0.1", "statsmodels>=0.12.2"],
         "forecast": ["prophet>=1.0.1", "statsmodels>=0.12.2"],
         "benchmark": ["catboost>=0.26", "psutil==5.8.0", "xgboost==1.3.3"],
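With seqeval added to both extras, installing the NLP option (e.g. pip install "flaml[nlp]") should pull in the metric package that the token-classification default metric relies on.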
@@ -40,3 +40,7 @@ def test_cv():
     }
 
     automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
+
+
+if __name__ == "__main__":
+    test_cv()
@@ -8,105 +8,176 @@ def test_mcc():
 
     import pandas as pd
 
-    train_data = {'video-id': ['anetv_fruimvo90vA', 'anetv_fruimvo90vA', 'anetv_fruimvo90vA', 'anetv_MldEr60j33M', 'lsmdc0049_Hannah_and_her_sisters-69438'],
-        'fold-ind': ['10030', '10030', '10030', '5488', '17405'],
-        'startphrase': ['A woman is seen running down a long track and jumping into a pit. The camera',
-            'A woman is seen running down a long track and jumping into a pit. The camera',
-            'A woman is seen running down a long track and jumping into a pit. The camera',
-            'A man in a white shirt bends over and picks up a large weight. He',
-            'Someone furiously shakes someone away. He'],
-        'sent1': ['A woman is seen running down a long track and jumping into a pit.',
-            'A woman is seen running down a long track and jumping into a pit.',
-            'A woman is seen running down a long track and jumping into a pit.',
-            'A man in a white shirt bends over and picks up a large weight.',
-            'Someone furiously shakes someone away.'],
-        'sent2': ['The camera', 'The camera', 'The camera', 'He', 'He'],
-        'gold-source': ['gen', 'gen', 'gold', 'gen', 'gold'],
-        'ending0': ['captures her as well as lifting weights down in place.',
-            'follows her spinning her body around and ends by walking down a lane.',
-            'watches her as she walks away and sticks her tongue out to another person.',
-            'lifts the weights over his head.',
-            'runs to a woman standing waiting.'],
-        'ending1': ['pans up to show another woman running down the track.',
-            'pans around the two.',
-            'captures her as well as lifting weights down in place.',
-            'also lifts it onto his chest before hanging it back out again.',
-            'tackles him into the passenger seat.'],
-        'ending2': ['follows her movements as the group members follow her instructions.',
-            'captures her as well as lifting weights down in place.',
-            'follows her spinning her body around and ends by walking down a lane.',
-            'spins around and lifts a barbell onto the floor.',
-            'pounds his fist against a cupboard.'],
-        'ending3': ['follows her spinning her body around and ends by walking down a lane.',
-            'follows her movements as the group members follow her instructions.',
-            'pans around the two.',
-            'bends down and lifts the weight over his head.',
-            'offers someone the cup on his elbow and strides out.'],
-        'label': [1, 3, 0, 0, 2]}
-    dev_data = {'video-id': ['lsmdc3001_21_JUMP_STREET-422',
-            'lsmdc0001_American_Beauty-45991',
-            'lsmdc0001_American_Beauty-45991',
-            'lsmdc0001_American_Beauty-45991'],
-        'fold-ind': ['11783', '10977', '10970', '10968'],
-        'startphrase': ['Firing wildly he shoots holes through the tanker. He',
-            'He puts his spatula down. The Mercedes',
-            'He stands and looks around, his eyes finally landing on: The digicam and a stack of cassettes on a shelf. Someone',
-            "He starts going through someone's bureau. He opens the drawer in which we know someone keeps his marijuana, but he"],
-        'sent1': ['Firing wildly he shoots holes through the tanker.',
-            'He puts his spatula down.',
-            'He stands and looks around, his eyes finally landing on: The digicam and a stack of cassettes on a shelf.',
-            "He starts going through someone's bureau."],
-        'sent2': ['He', 'The Mercedes', 'Someone', 'He opens the drawer in which we know someone keeps his marijuana, but he'],
-        'gold-source': ['gold', 'gold', 'gold', 'gold'],
-        'ending0': ['overtakes the rig and falls off his bike.',
-            'fly open and drinks.',
+    train_data = {
+        "video-id": [
+            "anetv_fruimvo90vA",
+            "anetv_fruimvo90vA",
+            "anetv_fruimvo90vA",
+            "anetv_MldEr60j33M",
+            "lsmdc0049_Hannah_and_her_sisters-69438",
+        ],
+        "fold-ind": ["10030", "10030", "10030", "5488", "17405"],
+        "startphrase": [
+            "A woman is seen running down a long track and jumping into a pit. The camera",
+            "A woman is seen running down a long track and jumping into a pit. The camera",
+            "A woman is seen running down a long track and jumping into a pit. The camera",
+            "A man in a white shirt bends over and picks up a large weight. He",
+            "Someone furiously shakes someone away. He",
+        ],
+        "sent1": [
+            "A woman is seen running down a long track and jumping into a pit.",
+            "A woman is seen running down a long track and jumping into a pit.",
+            "A woman is seen running down a long track and jumping into a pit.",
+            "A man in a white shirt bends over and picks up a large weight.",
+            "Someone furiously shakes someone away.",
+        ],
+        "sent2": ["The camera", "The camera", "The camera", "He", "He"],
+        "gold-source": ["gen", "gen", "gold", "gen", "gold"],
+        "ending0": [
+            "captures her as well as lifting weights down in place.",
+            "follows her spinning her body around and ends by walking down a lane.",
+            "watches her as she walks away and sticks her tongue out to another person.",
+            "lifts the weights over his head.",
+            "runs to a woman standing waiting.",
+        ],
+        "ending1": [
+            "pans up to show another woman running down the track.",
+            "pans around the two.",
+            "captures her as well as lifting weights down in place.",
+            "also lifts it onto his chest before hanging it back out again.",
+            "tackles him into the passenger seat.",
+        ],
+        "ending2": [
+            "follows her movements as the group members follow her instructions.",
+            "captures her as well as lifting weights down in place.",
+            "follows her spinning her body around and ends by walking down a lane.",
+            "spins around and lifts a barbell onto the floor.",
+            "pounds his fist against a cupboard.",
+        ],
+        "ending3": [
+            "follows her spinning her body around and ends by walking down a lane.",
+            "follows her movements as the group members follow her instructions.",
+            "pans around the two.",
+            "bends down and lifts the weight over his head.",
+            "offers someone the cup on his elbow and strides out.",
+        ],
+        "label": [1, 3, 0, 0, 2],
+    }
+    dev_data = {
+        "video-id": [
+            "lsmdc3001_21_JUMP_STREET-422",
+            "lsmdc0001_American_Beauty-45991",
+            "lsmdc0001_American_Beauty-45991",
+            "lsmdc0001_American_Beauty-45991",
+        ],
+        "fold-ind": ["11783", "10977", "10970", "10968"],
+        "startphrase": [
+            "Firing wildly he shoots holes through the tanker. He",
+            "He puts his spatula down. The Mercedes",
+            "He stands and looks around, his eyes finally landing on: "
+            "The digicam and a stack of cassettes on a shelf. Someone",
+            "He starts going through someone's bureau. He opens the drawer "
+            "in which we know someone keeps his marijuana, but he",
+        ],
+        "sent1": [
+            "Firing wildly he shoots holes through the tanker.",
+            "He puts his spatula down.",
+            "He stands and looks around, his eyes finally landing on: "
+            "The digicam and a stack of cassettes on a shelf.",
+            "He starts going through someone's bureau.",
+        ],
+        "sent2": [
+            "He",
+            "The Mercedes",
+            "Someone",
+            "He opens the drawer in which we know someone keeps his marijuana, but he",
+        ],
+        "gold-source": ["gold", "gold", "gold", "gold"],
+        "ending0": [
+            "overtakes the rig and falls off his bike.",
+            "fly open and drinks.",
             "looks at someone's papers.",
-            'stops one down and rubs a piece of the gift out.'],
-        'ending1': ['squeezes relentlessly on the peanut jelly as well.',
-            'walks off followed driveway again.',
-            'feels around it and falls in the seat once more.',
-            'cuts the mangled parts.'],
-        'ending2': ['scrambles behind himself and comes in other directions.',
-            'slots them into a separate green.',
-            'sprints back from the wreck and drops onto his back.',
-            'hides it under his hat to watch.'],
-        'ending3': ['sweeps a explodes and knocks someone off.',
-            'pulls around to the drive - thru window.',
-            'sits at the kitchen table, staring off into space.',
-            "does n't discover its false bottom."],
-        'label': [0, 3, 3, 3]}
-    test_data = {'video-id': ['lsmdc0001_American_Beauty-45991',
-            'lsmdc0001_American_Beauty-45991',
-            'lsmdc0001_American_Beauty-45991',
-            'lsmdc0001_American_Beauty-45991'],
-        'fold-ind': ['10980', '10976', '10978', '10969'],
-        'startphrase': ['Someone leans out of the drive - thru window, grinning at her, holding bags filled with fast food. The Counter Girl',
-            'Someone looks up suddenly when he hears. He',
-            'Someone drives; someone sits beside her. They',
-            "He opens the drawer in which we know someone keeps his marijuana, but he does n't discover its false bottom. He stands and looks around, his eyes"],
-        'sent1': ['Someone leans out of the drive - thru window, grinning at her, holding bags filled with fast food.',
-            'Someone looks up suddenly when he hears.',
-            'Someone drives; someone sits beside her.',
-            "He opens the drawer in which we know someone keeps his marijuana, but he does n't discover its false bottom."],
-        'sent2': ['The Counter Girl', 'He', 'They', 'He stands and looks around, his eyes'],
-        'gold-source': ['gold', 'gold', 'gold', 'gold'],
-        'ending0': ['stands next to him, staring blankly.',
-            'puts his spatula down.',
+            "stops one down and rubs a piece of the gift out.",
+        ],
+        "ending1": [
+            "squeezes relentlessly on the peanut jelly as well.",
+            "walks off followed driveway again.",
+            "feels around it and falls in the seat once more.",
+            "cuts the mangled parts.",
+        ],
+        "ending2": [
+            "scrambles behind himself and comes in other directions.",
+            "slots them into a separate green.",
+            "sprints back from the wreck and drops onto his back.",
+            "hides it under his hat to watch.",
+        ],
+        "ending3": [
+            "sweeps a explodes and knocks someone off.",
+            "pulls around to the drive - thru window.",
+            "sits at the kitchen table, staring off into space.",
+            "does n't discover its false bottom.",
+        ],
+        "label": [0, 3, 3, 3],
+    }
+    test_data = {
+        "video-id": [
+            "lsmdc0001_American_Beauty-45991",
+            "lsmdc0001_American_Beauty-45991",
+            "lsmdc0001_American_Beauty-45991",
+            "lsmdc0001_American_Beauty-45991",
+        ],
+        "fold-ind": ["10980", "10976", "10978", "10969"],
+        "startphrase": [
+            "Someone leans out of the drive - thru window, "
+            "grinning at her, holding bags filled with fast food. The Counter Girl",
+            "Someone looks up suddenly when he hears. He",
+            "Someone drives; someone sits beside her. They",
+            "He opens the drawer in which we know someone "
+            "keeps his marijuana, but he does n't discover"
+            " its false bottom. He stands and looks around, his eyes",
+        ],
+        "sent1": [
+            "Someone leans out of the drive - thru "
+            "window, grinning at her, holding bags filled with fast food.",
+            "Someone looks up suddenly when he hears.",
+            "Someone drives; someone sits beside her.",
+            "He opens the drawer in which we know"
+            " someone keeps his marijuana, but he does n't discover its false bottom.",
+        ],
+        "sent2": [
+            "The Counter Girl",
+            "He",
+            "They",
+            "He stands and looks around, his eyes",
+        ],
+        "gold-source": ["gold", "gold", "gold", "gold"],
+        "ending0": [
+            "stands next to him, staring blankly.",
+            "puts his spatula down.",
             "rise someone's feet up.",
-            'moving to the side, the houses rapidly stained.'],
-        'ending1': ['with auditorium, filmed, singers the club.',
-            'bumps into a revolver and drops surreptitiously into his weapon.',
-            'lift her and they are alarmed.',
-            'focused as the sight of someone making his way down a trail.'],
-        'ending2': ['attempts to block her ransacked.',
-            'talks using the phone and walks away for a few seconds.',
-            'are too involved with each other to notice someone watching them from the drive - thru window.',
-            'finally landing on: the digicam and a stack of cassettes on a shelf.'],
-        'ending3': ['is eating solid and stinky.',
-            'bundles the flaxen powder beneath the car.',
-            'sit at a table with a beer from a table.',
-            "deep and continuing, its bleed - length sideburns pressing on him."],
-        'label': [0, 0, 2, 2]}
+            "moving to the side, the houses rapidly stained.",
+        ],
+        "ending1": [
+            "with auditorium, filmed, singers the club.",
+            "bumps into a revolver and drops surreptitiously into his weapon.",
+            "lift her and they are alarmed.",
+            "focused as the sight of someone making his way down a trail.",
+        ],
+        "ending2": [
+            "attempts to block her ransacked.",
+            "talks using the phone and walks away for a few seconds.",
+            "are too involved with each other to "
+            "notice someone watching them from the drive - thru window.",
+            "finally landing on: the digicam and a stack of cassettes on a shelf.",
+        ],
+        "ending3": [
+            "is eating solid and stinky.",
+            "bundles the flaxen powder beneath the car.",
+            "sit at a table with a beer from a table.",
+            "deep and continuing, its bleed - length sideburns pressing on him.",
+        ],
+        "label": [0, 0, 2, 2],
+    }
 
     train_dataset = pd.DataFrame(train_data)
     dev_dataset = pd.DataFrame(dev_data)
test/nlp/test_autohf_tokenclassification.py (new file, 741 lines)
import sys
import pytest


@pytest.mark.skipif(sys.platform == "darwin", reason="do not run on mac os")
def test_tokenclassification():
    from flaml import AutoML
    import pandas as pd

    train_data = {
        "chunk_tags": [
            [11, 21, 11, 12, 21, 22, 11, 12, 0],
            [11, 12],
            [11, 12],
            [11, 12, 12, 21, 13, 11, 11, 21, 13, 11, 12, 13, 11, 21, 22, 11, 12,
             17, 11, 21, 17, 11, 12, 12, 21, 22, 22, 13, 11, 0],
        ],
        "id": ["0", "1", "2", "3"],
        "ner_tags": [
            [3, 0, 7, 0, 0, 0, 7, 0, 0],
            [1, 2],
            [5, 0],
            [0, 3, 4, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0],
        ],
        "pos_tags": [
            [22, 42, 16, 21, 35, 37, 16, 21, 7],
            [22, 22],
            [22, 11],
            [12, 22, 22, 38, 15, 22, 28, 38, 15, 16, 21, 35, 24, 35, 37, 16, 21,
             15, 24, 41, 15, 16, 21, 21, 20, 37, 40, 35, 21, 7],
        ],
        "tokens": [
            ["EU", "rejects", "German", "call", "to", "boycott", "British", "lamb", "."],
            ["Peter", "Blackburn"],
            ["BRUSSELS", "1996-08-22"],
            ["The", "European", "Commission", "said", "on", "Thursday", "it",
             "disagreed", "with", "German", "advice", "to", "consumers", "to",
             "shun", "British", "lamb", "until", "scientists", "determine",
             "whether", "mad", "cow", "disease", "can", "be", "transmitted",
             "to", "sheep", "."],
        ],
    }

    dev_data = {
        "chunk_tags": [
            [11, 11, 12, 13, 11, 12, 12, 11, 12, 12, 12, 12, 21, 13, 11, 12, 21,
             22, 11, 13, 11, 1, 13, 11, 17, 11, 12, 12, 21, 1, 0],
            [0, 11, 21, 22, 22, 11, 12, 12, 17, 11, 21, 22, 22, 11, 12, 13, 11,
             0, 0, 11, 12, 11, 12, 12, 12, 12, 12, 12, 21, 11, 12, 12, 0],
            [11, 21, 11, 12, 12, 21, 22, 0, 17, 11, 21, 22, 17, 11, 21, 22, 11,
             21, 22, 22, 13, 11, 12, 12, 0],
            [11, 21, 11, 12, 11, 12, 13, 11, 12, 12, 12, 12, 21, 22, 11, 12, 0,
             11, 0, 11, 12, 13, 11, 12, 12, 12, 12, 12, 21, 11, 12, 1, 2, 2, 11,
             21, 22, 11, 12, 0],
        ],
        "id": ["4", "5", "6", "7"],
        "ner_tags": [
            [5, 0, 0, 0, 0, 3, 4, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 5, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0,
             0, 0, 1, 2, 2, 2, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             3, 4, 0],
            [0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ],
        "pos_tags": [
            [22, 27, 21, 35, 12, 22, 22, 27, 16, 21, 22, 22, 38, 15, 22, 24, 20,
             37, 21, 15, 24, 16, 15, 22, 15, 12, 16, 21, 38, 17, 7],
            [0, 28, 41, 30, 37, 12, 16, 21, 15, 28, 41, 30, 37, 12, 24, 15, 28,
             6, 0, 12, 22, 27, 16, 21, 22, 22, 14, 22, 38, 12, 21, 21, 7],
            [28, 38, 16, 16, 21, 38, 40, 10, 15, 28, 38, 40, 15, 21, 38, 40, 28,
             20, 37, 40, 15, 12, 22, 22, 7],
            [28, 38, 12, 21, 16, 21, 15, 22, 22, 22, 22, 22, 35, 37, 21, 24, 6,
             24, 10, 16, 24, 15, 12, 21, 10, 21, 21, 24, 38, 12, 30, 16, 10, 16,
             21, 35, 37, 16, 21, 7],
        ],
        "tokens": [
            ["Germany", "'s", "representative", "to", "the", "European", "Union",
             "'s", "veterinary", "committee", "Werner", "Zwingmann", "said", "on",
             "Wednesday", "consumers", "should", "buy", "sheepmeat", "from",
             "countries", "other", "than", "Britain", "until", "the",
             "scientific", "advice", "was", "clearer", "."],
            ['"', "We", "do", "n't", "support", "any", "such", "recommendation",
             "because", "we", "do", "n't", "see", "any", "grounds", "for", "it",
             ",", '"', "the", "Commission", "'s", "chief", "spokesman",
             "Nikolaus", "van", "der", "Pas", "told", "a", "news", "briefing", "."],
            ["He", "said", "further", "scientific", "study", "was", "required",
             "and", "if", "it", "was", "found", "that", "action", "was",
             "needed", "it", "should", "be", "taken", "by", "the", "European",
             "Union", "."],
            ["He", "said", "a", "proposal", "last", "month", "by", "EU", "Farm",
             "Commissioner", "Franz", "Fischler", "to", "ban", "sheep", "brains",
             ",", "spleens", "and", "spinal", "cords", "from", "the", "human",
             "and", "animal", "food", "chains", "was", "a", "highly", "specific",
             "and", "precautionary", "move", "to", "protect", "human", "health",
             "."],
        ],
    }

    train_dataset = pd.DataFrame(train_data)
    dev_dataset = pd.DataFrame(dev_data)

    custom_sent_keys = ["tokens"]
    label_key = "ner_tags"

    X_train = train_dataset[custom_sent_keys]
    y_train = train_dataset[label_key]

    X_val = dev_dataset[custom_sent_keys]
    y_val = dev_dataset[label_key]

    automl = AutoML()

    automl_settings = {
        "gpu_per_trial": 0,
        "max_iter": 2,
        "time_budget": 5,
        "task": "token-classification",
        "metric": "seqeval",
    }

    automl_settings["custom_hpo_args"] = {
        "model_path": "bert-base-uncased",
        "output_dir": "test/data/output/",
        "ckpt_per_epoch": 5,
        "fp16": False,
    }

    automl.fit(
        X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, **automl_settings
    )


if __name__ == "__main__":
    test_tokenclassification()