commit 207b6935d9
parent 8602def1c4
@@ -40,6 +40,7 @@ from .config import (
 from .data import (
     concat,
     CLASSIFICATION,
+    TOKENCLASSIFICATION,
     TS_FORECAST,
     FORECAST,
     REGRESSION,
@@ -866,6 +867,8 @@ class AutoML(BaseEstimator):
 
         # check the validity of input dimensions under the nlp mode
         if _is_nlp_task(self._state.task):
+            from .nlp.utils import is_a_list_of_str
+
             is_all_str = True
             is_all_list = True
             for column in X.columns:
@@ -874,17 +877,25 @@ class AutoML(BaseEstimator):
                     "string",
                 ), "If the task is an NLP task, X can only contain text columns"
                 for each_cell in X[column]:
-                    if each_cell:
+                    if each_cell is not None:
                         is_str = isinstance(each_cell, str)
                         is_list_of_int = isinstance(each_cell, list) and all(
                             isinstance(x, int) for x in each_cell
                         )
-                        assert is_str or is_list_of_int, (
-                            "Each column of the input must either be str (untokenized) "
-                            "or a list of integers (tokenized)"
-                        )
+                        is_list_of_str = is_a_list_of_str(each_cell)
+                        if self._state.task == TOKENCLASSIFICATION:
+                            assert is_list_of_str, (
+                                "For the token-classification task, the input column needs to be a list of string,"
+                                "instead of string, e.g., ['EU', 'rejects','German', 'call','to','boycott','British','lamb','.',].",
+                                "For more examples, please refer to test/nlp/test_autohf_tokenclassification.py",
+                            )
+                        else:
+                            assert is_str or is_list_of_int, (
+                                "Each column of the input must either be str (untokenized) "
+                                "or a list of integers (tokenized)"
+                            )
                         is_all_str &= is_str
-                        is_all_list &= is_list_of_int
+                        is_all_list &= is_list_of_int or is_list_of_str
             assert is_all_str or is_all_list, (
                 "Currently FLAML only supports two modes for NLP: either all columns of X are string (non-tokenized), "
                 "or all columns of X are integer ids (tokenized)"
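
Note: a minimal sketch (not part of the commit; the DataFrames are made-up) of the three cell formats the check above accepts. Plain strings and lists of integer ids keep the two pre-existing modes; lists of strings are the new mode required for token classification.

    import pandas as pd

    X_raw = pd.DataFrame({"sentence": ["EU rejects German call"]})  # str: untokenized text
    X_ids = pd.DataFrame({"sentence": [[2142, 19845, 1528, 1840]]})  # list of int: pre-tokenized ids
    X_tok = pd.DataFrame({"tokens": [["EU", "rejects", "German", "call"]]})  # list of str: token classification
    for frame in (X_raw, X_ids, X_tok):
        cell = frame.iloc[0, 0]
        print(type(cell).__name__, isinstance(cell, (str, list)))
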
@@ -963,6 +974,7 @@ class AutoML(BaseEstimator):
             and self._auto_augment
             and self._state.fit_kwargs.get("sample_weight") is None
             and self._split_type in ["stratified", "uniform"]
+            and self._state.task != TOKENCLASSIFICATION
         ):
             # logger.info(f"label {pd.unique(y_train_all)}")
             label_set, counts = np.unique(y_train_all, return_counts=True)
@@ -15,12 +15,14 @@ from typing import Dict, Union, List
 # TODO: if your task is not specified in here, define your task as an all-capitalized word
 SEQCLASSIFICATION = "seq-classification"
 MULTICHOICECLASSIFICATION = "multichoice-classification"
+TOKENCLASSIFICATION = "token-classification"
 CLASSIFICATION = (
     "binary",
     "multi",
     "classification",
     SEQCLASSIFICATION,
     MULTICHOICECLASSIFICATION,
+    TOKENCLASSIFICATION,
 )
 SEQREGRESSION = "seq-regression"
 REGRESSION = ("regression", SEQREGRESSION)
@@ -34,6 +36,7 @@ NLU_TASKS = (
     SEQREGRESSION,
     SEQCLASSIFICATION,
     MULTICHOICECLASSIFICATION,
+    TOKENCLASSIFICATION,
 )
 
@@ -354,11 +357,10 @@ class DataTransformer:
                 datetime_columns,
             )
             self._drop = drop
 
         if (
-            task in CLASSIFICATION
-            or not pd.api.types.is_numeric_dtype(y)
+            (task in CLASSIFICATION or not pd.api.types.is_numeric_dtype(y))
             and task not in NLG_TASKS
+            and task != TOKENCLASSIFICATION
         ):
             from sklearn.preprocessing import LabelEncoder
 
flaml/ml.py (13 changed lines)
@@ -164,11 +164,21 @@ def metric_loss_score(
                     score = metric.compute(predictions=y_predict, references=y_true)[
                         metric_name
                     ].mid.fmeasure
+                elif metric_name == "seqeval":
+                    y_true = [
+                        [x for x in each_y_true if x != -100] for each_y_true in y_true
+                    ]
+                    y_pred = [
+                        y_predict[each_idx][: len(y_true[each_idx])]
+                        for each_idx in range(len(y_predict))
+                    ]
+                    score = metric.compute(predictions=y_pred, references=y_true)[
+                        "overall_accuracy"
+                    ]
                 else:
                     score = metric.compute(predictions=y_predict, references=y_true)[
                         metric_name
                     ]
 
             except ImportError:
                 raise Exception(
                     metric_name
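
Note: a sketch (with made-up label ids) of the alignment the seqeval branch above performs. Labels of special tokens were set to -100 at tokenization time, so they are filtered out of the references, and each prediction row is truncated to the reference length before the metric is computed.

    y_true = [[-100, 3, 0, 7, -100], [-100, 1, 2, -100]]
    y_predict = [[3, 3, 0, 7, 0], [1, 1, 2, 2]]

    y_true = [[x for x in row if x != -100] for row in y_true]
    y_pred = [pred[: len(ref)] for pred, ref in zip(y_predict, y_true)]
    assert [len(r) for r in y_true] == [len(p) for p in y_pred]
    print(y_true)  # [[3, 0, 7], [1, 2]]
    print(y_pred)  # [[3, 3, 0], [1, 1]]
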
@@ -226,6 +236,7 @@ def sklearn_metric_loss_score(
     Returns:
         score: A float number of the loss, the lower the better.
     """
 
+    metric_name = metric_name.lower()
 
     if "r2" == metric_name:
@@ -25,6 +25,7 @@ from .data import (
     TS_VALUE_COL,
     SEQCLASSIFICATION,
     SEQREGRESSION,
+    TOKENCLASSIFICATION,
     SUMMARIZATION,
     NLG_TASKS,
     MULTICHOICECLASSIFICATION,
@@ -310,7 +311,8 @@ class TransformersEstimator(BaseEstimator):
 
     @staticmethod
     def _join(X_train, y_train):
-        y_train = DataFrame(y_train, columns=["label"], index=X_train.index)
+        y_train = DataFrame(y_train, index=X_train.index)
+        y_train.columns = ["label"]
         train_df = X_train.join(y_train)
         return train_df
 
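
Note: a sketch of the pandas behavior that presumably motivates the two-step rename above. When the constructor's data already carries a column label (a named Series), the columns= argument selects rather than renames, so a Series not named "label" yields a column of NaN; assigning .columns afterwards always renames.

    import pandas as pd

    y = pd.Series([0, 1, 2], name="y")
    print(pd.DataFrame(y, columns=["label"]))  # "label" not found in data -> all NaN
    df = pd.DataFrame(y)
    df.columns = ["label"]  # rename after construction instead
    print(df)
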
@@ -370,17 +372,12 @@ class TransformersEstimator(BaseEstimator):
         self.custom_hpo_args = custom_hpo_args
 
     def _preprocess(self, X, y=None, **kwargs):
-        from .nlp.utils import tokenize_text
+        from .nlp.utils import tokenize_text, is_a_list_of_str
 
-        # is_str = False
-        # for each_type in ["string", "str"]:
-        #     try:
-        #         is_str = is_str or (X.dtypes[0] == each_type)
-        #     except TypeError:
-        #         pass
+        is_str = str(X.dtypes[0]) in ("string", "str")
+        is_list_of_str = is_a_list_of_str(X[list(X.keys())[0]].to_list()[0])
 
-        if is_str:
+        if is_str or is_list_of_str:
             return tokenize_text(
                 X=X, Y=y, task=self._task, custom_hpo_args=self.custom_hpo_args
             )
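
Note: a sketch (toy frame) of the simplified dtype probe above: the old try/except loop over dtype names is replaced by a direct string comparison, and the first cell is additionally probed for the list-of-str (token classification) mode.

    import pandas as pd

    X = pd.DataFrame({"sentence": pd.array(["hello world"], dtype="string")})
    print(str(X.dtypes[0]))  # "string" -> still needs tokenization
    first_cell = X[list(X.keys())[0]].to_list()[0]
    print(isinstance(first_cell, list))  # False here; True for token-classification input
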
@@ -391,6 +388,7 @@ class TransformersEstimator(BaseEstimator):
         from transformers import EarlyStoppingCallback
         from transformers.trainer_utils import set_seed
         from transformers import AutoTokenizer
+        from transformers.data import DataCollatorWithPadding
 
         import transformers
         from datasets import Dataset
@@ -455,7 +453,7 @@ class TransformersEstimator(BaseEstimator):
         X_val = kwargs.get("X_val")
         y_val = kwargs.get("y_val")
 
-        if self._task not in NLG_TASKS:
+        if (self._task not in NLG_TASKS) and (self._task != TOKENCLASSIFICATION):
             self._X_train, _ = self._preprocess(X=X_train, **kwargs)
             self._y_train = y_train
         else:
@@ -474,7 +472,7 @@ class TransformersEstimator(BaseEstimator):
         #  make sure they are the same
 
         if X_val is not None:
-            if self._task not in NLG_TASKS:
+            if (self._task not in NLG_TASKS) and (self._task != TOKENCLASSIFICATION):
                 self._X_val, _ = self._preprocess(X=X_val, **kwargs)
                 self._y_val = y_val
             else:
@@ -648,6 +646,8 @@ class TransformersEstimator(BaseEstimator):
                 predictions = (
                     np.squeeze(predictions)
                     if self._task == SEQREGRESSION
+                    else np.argmax(predictions, axis=2)
+                    if self._task == TOKENCLASSIFICATION
                     else np.argmax(predictions, axis=1)
                 )
             return {
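
Note: a sketch (random numbers) of why the token-classification branch above takes argmax over axis 2: the model emits one logit vector per token, shape (batch, sequence_length, num_labels), whereas sequence classification emits (batch, num_labels) and uses axis 1.

    import numpy as np

    logits = np.random.rand(4, 16, 9)  # 4 examples, 16 tokens each, 9 NER tags
    per_token_pred = np.argmax(logits, axis=2)
    print(per_token_pred.shape)  # (4, 16): one tag id per token
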
@@ -724,7 +724,9 @@ class TransformersEstimator(BaseEstimator):
         if self._task == SEQCLASSIFICATION:
             return np.argmax(predictions.predictions, axis=1)
         elif self._task == SEQREGRESSION:
-            return predictions.predictions
+            return predictions.predictions.reshape((len(predictions.predictions),))
+        elif self._task == TOKENCLASSIFICATION:
+            return np.argmax(predictions.predictions, axis=2)
         # TODO: elif self._task == your task, return the corresponding prediction
         #  e.g., if your task == QUESTIONANSWERING, you need to return the answer instead
         #  of the index
 
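
Note: a sketch of the regression reshape above: the trainer returns regression predictions with shape (n, 1), which are flattened to (n,) so they line up with sklearn-style y vectors.

    import numpy as np

    preds = np.array([[0.1], [0.7], [0.3]])
    print(preds.reshape((len(preds),)))  # [0.1 0.7 0.3]
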
@@ -5,9 +5,14 @@ import transformers
 if transformers.__version__.startswith("3"):
     from transformers.modeling_electra import ElectraClassificationHead
     from transformers.modeling_roberta import RobertaClassificationHead
+    from transformers.models.electra.modeling_electra import ElectraForTokenClassification
+    from transformers.models.roberta.modeling_roberta import RobertaForTokenClassification
+
 else:
     from transformers.models.electra.modeling_electra import ElectraClassificationHead
     from transformers.models.roberta.modeling_roberta import RobertaClassificationHead
+    from transformers.models.electra.modeling_electra import ElectraForTokenClassification
+    from transformers.models.roberta.modeling_roberta import RobertaForTokenClassification
 
 MODEL_CLASSIFICATION_HEAD_MAPPING = OrderedDict(
     [
 
@@ -7,12 +7,14 @@ from ..data import (
     SUMMARIZATION,
     SEQREGRESSION,
     SEQCLASSIFICATION,
-    NLG_TASKS,
     MULTICHOICECLASSIFICATION,
+    TOKENCLASSIFICATION,
+    NLG_TASKS,
 )
 
 
 def load_default_huggingface_metric_for_task(task):
 
     if task == SEQCLASSIFICATION:
         return "accuracy", "max"
     elif task == SEQREGRESSION:
@@ -20,15 +22,9 @@ def load_default_huggingface_metric_for_task(task):
     elif task == SUMMARIZATION:
         return "rouge", "max"
     elif task == MULTICHOICECLASSIFICATION:
-        return "accuracy"
-    # TODO: elif task == your task, return the default metric name for your task,
-    #  e.g., if task == MULTIPLECHOICE, return "accuracy"
-    #  notice this metric name has to be in ['accuracy', 'bertscore', 'bleu', 'bleurt',
-    #  'cer', 'chrf', 'code_eval', 'comet', 'competition_math', 'coval', 'cuad',
-    #  'f1', 'gleu', 'glue', 'google_bleu', 'indic_glue', 'matthews_correlation',
-    #  'meteor', 'pearsonr', 'precision', 'recall', 'rouge', 'sacrebleu', 'sari',
-    #  'seqeval', 'spearmanr', 'squad', 'squad_v2', 'super_glue', 'ter', 'wer',
-    #  'wiki_split', 'xnli']
+        return "accuracy", "max"
+    elif task == TOKENCLASSIFICATION:
+        return "seqeval", "max"
 
 
 global tokenized_column_names
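
Note: a sketch (assuming flaml at this commit is importable) of the new default pair: the first element names the HuggingFace metric, the second the direction the tuner should optimize.

    from flaml.nlp.utils import load_default_huggingface_metric_for_task

    print(load_default_huggingface_metric_for_task("token-classification"))  # ("seqeval", "max")
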
@@ -40,6 +36,8 @@ def tokenize_text(X, Y=None, task=None, custom_hpo_args=None):
             X, this_tokenizer=None, task=task, custom_hpo_args=custom_hpo_args
         )
         return X_tokenized, None
+    elif task == TOKENCLASSIFICATION:
+        return tokenize_text_tokclassification(X, Y, custom_hpo_args)
     elif task in NLG_TASKS:
         return tokenize_seq2seq(X, Y, task=task, custom_hpo_args=custom_hpo_args)
     elif task == MULTICHOICECLASSIFICATION:
@@ -71,11 +69,107 @@ def tokenize_seq2seq(X, Y, task=None, custom_hpo_args=None):
     return model_inputs, labels
 
 
+def tokenize_and_align_labels(
+    examples, tokenizer, custom_hpo_args, X_sent_key, Y_sent_key=None
+):
+    global tokenized_column_names
+
+    tokenized_inputs = tokenizer(
+        [list(examples[X_sent_key])],
+        padding="max_length",
+        truncation=True,
+        max_length=custom_hpo_args.max_seq_length,
+        # We use this argument because the texts in our dataset are lists of words (with a label for each word).
+        is_split_into_words=True,
+    )
+    if Y_sent_key is not None:
+        previous_word_idx = None
+        label_ids = []
+        import numbers
+
+        for word_idx in tokenized_inputs.word_ids(batch_index=0):
+            # Special tokens have a word id that is None. We set the label to -100 so they are automatically
+            # ignored in the loss function.
+            if word_idx is None:
+                label_ids.append(-100)
+            # We set the label for the first token of each word.
+            elif word_idx != previous_word_idx:
+                if isinstance(examples[Y_sent_key][word_idx], numbers.Number):
+                    label_ids.append(examples[Y_sent_key][word_idx])
+                # else:
+                #     label_ids.append(label_to_id[label[word_idx]])
+            # For the other tokens in a word, we set the label to either the current label or -100, depending on
+            # the label_all_tokens flag.
+            else:
+                if isinstance(examples[Y_sent_key][word_idx], numbers.Number):
+                    label_ids.append(examples[Y_sent_key][word_idx])
+                # else:
+                #     label_ids.append(b_to_i_label[label_to_id[label[word_idx]]])
+            previous_word_idx = word_idx
+        tokenized_inputs["label"] = label_ids
+    tokenized_column_names = sorted(tokenized_inputs.keys())
+    tokenized_input_and_labels = [tokenized_inputs[x] for x in tokenized_column_names]
+    for key_idx, each_key in enumerate(tokenized_column_names):
+        if each_key != "label":
+            tokenized_input_and_labels[key_idx] = tokenized_input_and_labels[key_idx][0]
+    return tokenized_input_and_labels
+
+
+def tokenize_text_tokclassification(X, Y, custom_hpo_args):
+    from transformers import AutoTokenizer
+    import pandas as pd
+
+    global tokenized_column_names
+    this_tokenizer = AutoTokenizer.from_pretrained(
+        custom_hpo_args.model_path, use_fast=True
+    )
+    if Y is not None:
+        X_and_Y = pd.concat([X, Y.to_frame()], axis=1)
+        X_key = list(X.keys())[0]
+        Y_key = list(Y.to_frame().keys())[0]
+        X_and_Y_tokenized = X_and_Y.apply(
+            lambda x: tokenize_and_align_labels(
+                x,
+                tokenizer=this_tokenizer,
+                custom_hpo_args=custom_hpo_args,
+                X_sent_key=X_key,
+                Y_sent_key=Y_key,
+            ),
+            axis=1,
+            result_type="expand",
+        )
+        label_idx = tokenized_column_names.index("label")
+        other_indices = sorted(
+            set(range(len(tokenized_column_names))).difference({label_idx})
+        )
+        other_column_names = [tokenized_column_names[x] for x in other_indices]
+        d = X_and_Y_tokenized.iloc[:, other_indices]
+        y_tokenized = X_and_Y_tokenized.iloc[:, label_idx]
+    else:
+        X_key = list(X.keys())[0]
+        d = X.apply(
+            lambda x: tokenize_and_align_labels(
+                x,
+                tokenizer=this_tokenizer,
+                custom_hpo_args=custom_hpo_args,
+                X_sent_key=X_key,
+                Y_sent_key=None,
+            ),
+            axis=1,
+            result_type="expand",
+        )
+        other_column_names = tokenized_column_names
+        y_tokenized = None
+    X_tokenized = pd.DataFrame(columns=other_column_names)
+    X_tokenized[other_column_names] = d
+    return X_tokenized, y_tokenized
+
+
 def tokenize_onedataframe(
-        X,
-        this_tokenizer=None,
-        task=None,
-        custom_hpo_args=None,
+    X,
+    this_tokenizer=None,
+    task=None,
+    custom_hpo_args=None,
 ):
     from transformers import AutoTokenizer
     import pandas
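
Note: a sketch (downloads bert-base-uncased, so network access is assumed) of the word_ids() bookkeeping used in tokenize_and_align_labels above: subword pieces of one word share a word id, and special tokens map to None, which is labeled -100.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
    words = ["EU", "rejects", "German", "call"]
    ner_tags = [3, 0, 7, 0]

    encoding = tokenizer([words], is_split_into_words=True)
    word_ids = encoding.word_ids(batch_index=0)
    label_ids = [-100 if idx is None else ner_tags[idx] for idx in word_ids]
    print(word_ids)   # e.g. [None, 0, 1, 2, 3, None]
    print(label_ids)  # e.g. [-100, 3, 0, 7, 0, -100]
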
@@ -130,11 +224,11 @@ def postprocess_text(preds, labels):
 
 
 def tokenize_row(
-        this_row, this_tokenizer, prefix=None, task=None, custom_hpo_args=None
+    this_row, this_tokenizer, prefix=None, task=None, custom_hpo_args=None
 ):
     global tokenized_column_names
     assert (
-            "max_seq_length" in custom_hpo_args.__dict__
+        "max_seq_length" in custom_hpo_args.__dict__
     ), "max_seq_length must be provided for glue"
 
     if prefix:
@@ -229,16 +323,22 @@ def separate_config(config, task):
 
 
 def get_num_labels(task, y_train):
-    from ..data import SEQCLASSIFICATION, SEQREGRESSION
+    from ..data import SEQCLASSIFICATION, SEQREGRESSION, TOKENCLASSIFICATION
 
     if task == SEQREGRESSION:
         return 1
     elif task == SEQCLASSIFICATION:
         return len(set(y_train))
+    elif task == TOKENCLASSIFICATION:
+        return len(set([a for b in y_train.tolist() for a in b]))
     else:
         return None
 
 
+def is_a_list_of_str(this_obj):
+    return isinstance(this_obj, list) and all(isinstance(x, str) for x in this_obj)
+
+
 def _clean_value(value: Any) -> str:
     if isinstance(value, float):
         return "{:.5}".format(value)
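
Note: a sketch (toy labels) of the nested counting added to get_num_labels above: for token classification y is a Series of per-token tag lists, so the labels are flattened before counting distinct classes.

    import pandas as pd

    y_train = pd.Series([[3, 0, 7, 0], [1, 2], [5, 0]])
    print(len(set(a for b in y_train.tolist() for a in b)))  # 6 distinct tags: {0, 1, 2, 3, 5, 7}
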
@@ -305,7 +405,7 @@ def load_model(checkpoint_path, task, num_labels, per_model_config=None):
         AutoSeqClassificationHead,
         MODEL_CLASSIFICATION_HEAD_MAPPING,
     )
-    from ..data import SEQCLASSIFICATION, SEQREGRESSION
+    from ..data import SEQCLASSIFICATION, SEQREGRESSION, TOKENCLASSIFICATION
 
     this_model_type = AutoConfig.from_pretrained(checkpoint_path).model_type
     this_vocab_size = AutoConfig.from_pretrained(checkpoint_path).vocab_size
@@ -314,15 +414,16 @@ def load_model(checkpoint_path, task, num_labels, per_model_config=None):
         from transformers import AutoModelForSequenceClassification
         from transformers import AutoModelForSeq2SeqLM
         from transformers import AutoModelForMultipleChoice
+        from transformers import AutoModelForTokenClassification
 
         if task in (SEQCLASSIFICATION, SEQREGRESSION):
             return AutoModelForSequenceClassification.from_pretrained(
                 checkpoint_path, config=model_config
             )
-        # TODO: elif task == your task, fill in the line in your transformers example
-        #  that loads the model, e.g., if task == MULTIPLE CHOICE, according to
-        #  https://github.com/huggingface/transformers/blob/master/examples/pytorch/multiple-choice/run_swag.py#L298
-        #  you can return AutoModelForMultipleChoice.from_pretrained(checkpoint_path, config=model_config)
+        elif task == TOKENCLASSIFICATION:
+            return AutoModelForTokenClassification.from_pretrained(
+                checkpoint_path, config=model_config
+            )
         elif task in NLG_TASKS:
             return AutoModelForSeq2SeqLM.from_pretrained(
                 checkpoint_path, config=model_config
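
Note: a sketch (downloads bert-base-uncased; the checkpoint name is illustrative) of the branch added above: the checkpoint is loaded with a token-classification head sized by num_labels from the config.

    from transformers import AutoConfig, AutoModelForTokenClassification

    config = AutoConfig.from_pretrained("bert-base-uncased", num_labels=9)
    model = AutoModelForTokenClassification.from_pretrained("bert-base-uncased", config=config)
    print(model.classifier.out_features)  # 9
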
@@ -336,7 +437,7 @@ def load_model(checkpoint_path, task, num_labels, per_model_config=None):
         return model_type in MODEL_CLASSIFICATION_HEAD_MAPPING
 
     def _set_model_config(checkpoint_path):
-        if task in (SEQCLASSIFICATION, SEQREGRESSION):
+        if task in (SEQCLASSIFICATION, SEQREGRESSION, TOKENCLASSIFICATION):
             if per_model_config:
                 model_config = AutoConfig.from_pretrained(
                     checkpoint_path,
@@ -385,25 +486,27 @@
     else:
         if task == SEQREGRESSION:
             model_config_num_labels = 1
+        elif task == TOKENCLASSIFICATION:
+            model_config_num_labels = num_labels
         model_config = _set_model_config(checkpoint_path)
         this_model = get_this_model(task)
         return this_model
 
 
 def compute_checkpoint_freq(
-        train_data_size,
-        custom_hpo_args,
-        num_train_epochs,
-        batch_size,
+    train_data_size,
+    custom_hpo_args,
+    num_train_epochs,
+    batch_size,
 ):
     ckpt_step_freq = (
-            int(
-                min(num_train_epochs, 1)
-                * train_data_size
-                / batch_size
-                / custom_hpo_args.ckpt_per_epoch
-            )
-            + 1
+        int(
+            min(num_train_epochs, 1)
+            * train_data_size
+            / batch_size
+            / custom_hpo_args.ckpt_per_epoch
+        )
+        + 1
     )
     return ckpt_step_freq
 
@@ -411,7 +514,6 @@ def compute_checkpoint_freq(
 @dataclass
 class HPOArgs:
     """The HPO setting.
-
     Args:
         output_dir (str): data root directory for outputing the log, etc.
         model_path (str, optional, defaults to "facebook/muppet-roberta-base"): A string,
@@ -420,7 +522,6 @@ class HPOArgs:
         fp16 (bool, optional, defaults to "False"): A bool, whether to use FP16.
         max_seq_length (int, optional, defaults to 128): An integer, the max length of the sequence.
         ckpt_per_epoch (int, optional, defaults to 1): An integer, the number of checkpoints per epoch.
-
     """
 
     output_dir: str = field(
@@ -436,6 +537,15 @@ class HPOArgs:
 
     max_seq_length: int = field(default=128, metadata={"help": "max seq length"})
 
+    pad_to_max_length: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to pad all samples to model maximum sentence length. "
+            "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
+            "efficient on GPU but very bad for TPU."
+        },
+    )
+
     ckpt_per_epoch: int = field(default=1, metadata={"help": "checkpoint per epoch"})
 
     @staticmethod
setup.py (3 changed lines)
@@ -60,6 +60,7 @@ setuptools.setup(
             "torch",
             "nltk",
             "rouge_score",
+            "seqeval",
         ],
         "catboost": ["catboost>=0.26"],
         "blendsearch": ["optuna==2.8.0"],
@@ -76,7 +77,7 @@ setuptools.setup(
         "vw": [
             "vowpalwabbit",
         ],
-        "nlp": ["transformers", "datasets", "torch", "nltk", "rouge_score"],
+        "nlp": ["transformers", "datasets", "torch", "seqeval", "nltk", "rouge_score"],
         "ts_forecast": ["prophet>=1.0.1", "statsmodels>=0.12.2"],
         "forecast": ["prophet>=1.0.1", "statsmodels>=0.12.2"],
         "benchmark": ["catboost>=0.26", "psutil==5.8.0", "xgboost==1.3.3"],
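
Note: with seqeval added to the nlp extra, `pip install flaml[nlp]` pulls in everything the new task needs. Below is a minimal end-to-end sketch; the fit settings are illustrative (patterned on the new test file) rather than taken verbatim from the commit.

    import pandas as pd
    from flaml import AutoML

    X_train = pd.DataFrame({"tokens": [["EU", "rejects", "German", "call"], ["Peter", "Blackburn"]]})
    y_train = pd.Series([[3, 0, 7, 0], [1, 2]])

    automl = AutoML()
    automl.fit(
        X_train=X_train,
        y_train=y_train,
        task="token-classification",  # TOKENCLASSIFICATION
        metric="seqeval",             # the new default metric for this task
        time_budget=300,
        # custom HPO settings (model_path, max_seq_length, ckpt_per_epoch, ...)
        # can be supplied as in test/nlp/test_autohf_tokenclassification.py
    )
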
@@ -40,3 +40,7 @@ def test_cv():
     }
 
     automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
+
+
+if __name__ == "__main__":
+    test_cv()
@@ -8,105 +8,176 @@ def test_mcc():
 
     import pandas as pd
 
-    train_data = {'video-id': ['anetv_fruimvo90vA', 'anetv_fruimvo90vA', 'anetv_fruimvo90vA', 'anetv_MldEr60j33M', 'lsmdc0049_Hannah_and_her_sisters-69438'],
-            'fold-ind': ['10030', '10030', '10030', '5488', '17405'],
-            'startphrase': ['A woman is seen running down a long track and jumping into a pit. The camera',
-                            'A woman is seen running down a long track and jumping into a pit. The camera',
-                            'A woman is seen running down a long track and jumping into a pit. The camera',
-                            'A man in a white shirt bends over and picks up a large weight. He',
-                            'Someone furiously shakes someone away. He'],
-            'sent1': ['A woman is seen running down a long track and jumping into a pit.',
-                      'A woman is seen running down a long track and jumping into a pit.',
-                      'A woman is seen running down a long track and jumping into a pit.',
-                      'A man in a white shirt bends over and picks up a large weight.',
-                      'Someone furiously shakes someone away.'],
-            'sent2': ['The camera', 'The camera', 'The camera', 'He', 'He'],
-            'gold-source': ['gen', 'gen', 'gold', 'gen', 'gold'],
-            'ending0': ['captures her as well as lifting weights down in place.',
-                        'follows her spinning her body around and ends by walking down a lane.',
-                        'watches her as she walks away and sticks her tongue out to another person.',
-                        'lifts the weights over his head.',
-                        'runs to a woman standing waiting.'],
-            'ending1': ['pans up to show another woman running down the track.',
-                        'pans around the two.',
-                        'captures her as well as lifting weights down in place.',
-                        'also lifts it onto his chest before hanging it back out again.',
-                        'tackles him into the passenger seat.'],
-            'ending2': ['follows her movements as the group members follow her instructions.',
-                        'captures her as well as lifting weights down in place.',
-                        'follows her spinning her body around and ends by walking down a lane.',
-                        'spins around and lifts a barbell onto the floor.',
-                        'pounds his fist against a cupboard.'],
-            'ending3': ['follows her spinning her body around and ends by walking down a lane.',
-                        'follows her movements as the group members follow her instructions.',
-                        'pans around the two.',
-                        'bends down and lifts the weight over his head.',
-                        'offers someone the cup on his elbow and strides out.'],
-            'label': [1, 3, 0, 0, 2]}
-    dev_data = {'video-id': ['lsmdc3001_21_JUMP_STREET-422',
-                             'lsmdc0001_American_Beauty-45991',
-                             'lsmdc0001_American_Beauty-45991',
-                             'lsmdc0001_American_Beauty-45991'],
-            'fold-ind': ['11783', '10977', '10970', '10968'],
-            'startphrase': ['Firing wildly he shoots holes through the tanker. He',
-                            'He puts his spatula down. The Mercedes',
-                            'He stands and looks around, his eyes finally landing on: The digicam and a stack of cassettes on a shelf. Someone',
-                            "He starts going through someone's bureau. He opens the drawer in which we know someone keeps his marijuana, but he"],
-            'sent1': ['Firing wildly he shoots holes through the tanker.',
-                      'He puts his spatula down.',
-                      'He stands and looks around, his eyes finally landing on: The digicam and a stack of cassettes on a shelf.',
-                      "He starts going through someone's bureau."],
-            'sent2': ['He', 'The Mercedes', 'Someone', 'He opens the drawer in which we know someone keeps his marijuana, but he'],
-            'gold-source': ['gold', 'gold', 'gold', 'gold'],
-            'ending0': ['overtakes the rig and falls off his bike.',
-                        'fly open and drinks.',
-                        "looks at someone's papers.",
-                        'stops one down and rubs a piece of the gift out.'],
-            'ending1': ['squeezes relentlessly on the peanut jelly as well.',
-                        'walks off followed driveway again.',
-                        'feels around it and falls in the seat once more.',
-                        'cuts the mangled parts.'],
-            'ending2': ['scrambles behind himself and comes in other directions.',
-                        'slots them into a separate green.',
-                        'sprints back from the wreck and drops onto his back.',
-                        'hides it under his hat to watch.'],
-            'ending3': ['sweeps a explodes and knocks someone off.',
-                        'pulls around to the drive - thru window.',
-                        'sits at the kitchen table, staring off into space.',
-                        "does n't discover its false bottom."],
-            'label': [0, 3, 3, 3]}
-    test_data = {'video-id': ['lsmdc0001_American_Beauty-45991',
-                             'lsmdc0001_American_Beauty-45991',
-                             'lsmdc0001_American_Beauty-45991',
-                             'lsmdc0001_American_Beauty-45991'],
-            'fold-ind': ['10980', '10976', '10978', '10969'],
-            'startphrase': ['Someone leans out of the drive - thru window, grinning at her, holding bags filled with fast food. The Counter Girl',
-                            'Someone looks up suddenly when he hears. He',
-                            'Someone drives; someone sits beside her. They',
-                            "He opens the drawer in which we know someone keeps his marijuana, but he does n't discover its false bottom. He stands and looks around, his eyes"],
-            'sent1': ['Someone leans out of the drive - thru window, grinning at her, holding bags filled with fast food.',
-                      'Someone looks up suddenly when he hears.',
-                      'Someone drives; someone sits beside her.',
-                      "He opens the drawer in which we know someone keeps his marijuana, but he does n't discover its false bottom."],
-            'sent2': ['The Counter Girl', 'He', 'They', 'He stands and looks around, his eyes'],
-            'gold-source': ['gold', 'gold', 'gold', 'gold'],
-            'ending0': ['stands next to him, staring blankly.',
-                        'puts his spatula down.',
-                        "rise someone's feet up.",
-                        'moving to the side, the houses rapidly stained.'],
-            'ending1': ['with auditorium, filmed, singers the club.',
-                        'bumps into a revolver and drops surreptitiously into his weapon.',
-                        'lift her and they are alarmed.',
-                        'focused as the sight of someone making his way down a trail.'],
-            'ending2': ['attempts to block her ransacked.',
-                        'talks using the phone and walks away for a few seconds.',
-                        'are too involved with each other to notice someone watching them from the drive - thru window.',
-                        'finally landing on: the digicam and a stack of cassettes on a shelf.'],
-            'ending3': ['is eating solid and stinky.',
-                        'bundles the flaxen powder beneath the car.',
-                        'sit at a table with a beer from a table.',
-                        "deep and continuing, its bleed - length sideburns pressing on him."],
-            'label': [0, 0, 2, 2]}
+    train_data = {
+        "video-id": [
+            "anetv_fruimvo90vA",
+            "anetv_fruimvo90vA",
+            "anetv_fruimvo90vA",
+            "anetv_MldEr60j33M",
+            "lsmdc0049_Hannah_and_her_sisters-69438",
+        ],
+        "fold-ind": ["10030", "10030", "10030", "5488", "17405"],
+        "startphrase": [
+            "A woman is seen running down a long track and jumping into a pit. The camera",
+            "A woman is seen running down a long track and jumping into a pit. The camera",
+            "A woman is seen running down a long track and jumping into a pit. The camera",
+            "A man in a white shirt bends over and picks up a large weight. He",
+            "Someone furiously shakes someone away. He",
+        ],
+        "sent1": [
+            "A woman is seen running down a long track and jumping into a pit.",
+            "A woman is seen running down a long track and jumping into a pit.",
+            "A woman is seen running down a long track and jumping into a pit.",
+            "A man in a white shirt bends over and picks up a large weight.",
+            "Someone furiously shakes someone away.",
+        ],
+        "sent2": ["The camera", "The camera", "The camera", "He", "He"],
+        "gold-source": ["gen", "gen", "gold", "gen", "gold"],
+        "ending0": [
+            "captures her as well as lifting weights down in place.",
+            "follows her spinning her body around and ends by walking down a lane.",
+            "watches her as she walks away and sticks her tongue out to another person.",
+            "lifts the weights over his head.",
+            "runs to a woman standing waiting.",
+        ],
+        "ending1": [
+            "pans up to show another woman running down the track.",
+            "pans around the two.",
+            "captures her as well as lifting weights down in place.",
+            "also lifts it onto his chest before hanging it back out again.",
+            "tackles him into the passenger seat.",
+        ],
+        "ending2": [
+            "follows her movements as the group members follow her instructions.",
+            "captures her as well as lifting weights down in place.",
+            "follows her spinning her body around and ends by walking down a lane.",
+            "spins around and lifts a barbell onto the floor.",
+            "pounds his fist against a cupboard.",
+        ],
+        "ending3": [
+            "follows her spinning her body around and ends by walking down a lane.",
+            "follows her movements as the group members follow her instructions.",
+            "pans around the two.",
+            "bends down and lifts the weight over his head.",
+            "offers someone the cup on his elbow and strides out.",
+        ],
+        "label": [1, 3, 0, 0, 2],
+    }
+    dev_data = {
+        "video-id": [
+            "lsmdc3001_21_JUMP_STREET-422",
+            "lsmdc0001_American_Beauty-45991",
+            "lsmdc0001_American_Beauty-45991",
+            "lsmdc0001_American_Beauty-45991",
+        ],
+        "fold-ind": ["11783", "10977", "10970", "10968"],
+        "startphrase": [
+            "Firing wildly he shoots holes through the tanker. He",
+            "He puts his spatula down. The Mercedes",
+            "He stands and looks around, his eyes finally landing on: "
+            "The digicam and a stack of cassettes on a shelf. Someone",
+            "He starts going through someone's bureau. He opens the drawer "
+            "in which we know someone keeps his marijuana, but he",
+        ],
+        "sent1": [
+            "Firing wildly he shoots holes through the tanker.",
+            "He puts his spatula down.",
+            "He stands and looks around, his eyes finally landing on: "
+            "The digicam and a stack of cassettes on a shelf.",
+            "He starts going through someone's bureau.",
+        ],
+        "sent2": [
+            "He",
+            "The Mercedes",
+            "Someone",
+            "He opens the drawer in which we know someone keeps his marijuana, but he",
+        ],
+        "gold-source": ["gold", "gold", "gold", "gold"],
+        "ending0": [
+            "overtakes the rig and falls off his bike.",
+            "fly open and drinks.",
+            "looks at someone's papers.",
+            "stops one down and rubs a piece of the gift out.",
+        ],
+        "ending1": [
+            "squeezes relentlessly on the peanut jelly as well.",
+            "walks off followed driveway again.",
+            "feels around it and falls in the seat once more.",
+            "cuts the mangled parts.",
+        ],
+        "ending2": [
+            "scrambles behind himself and comes in other directions.",
+            "slots them into a separate green.",
+            "sprints back from the wreck and drops onto his back.",
+            "hides it under his hat to watch.",
+        ],
+        "ending3": [
+            "sweeps a explodes and knocks someone off.",
+            "pulls around to the drive - thru window.",
+            "sits at the kitchen table, staring off into space.",
+            "does n't discover its false bottom.",
+        ],
+        "label": [0, 3, 3, 3],
+    }
+    test_data = {
+        "video-id": [
+            "lsmdc0001_American_Beauty-45991",
+            "lsmdc0001_American_Beauty-45991",
+            "lsmdc0001_American_Beauty-45991",
+            "lsmdc0001_American_Beauty-45991",
+        ],
+        "fold-ind": ["10980", "10976", "10978", "10969"],
+        "startphrase": [
+            "Someone leans out of the drive - thru window, "
+            "grinning at her, holding bags filled with fast food. The Counter Girl",
+            "Someone looks up suddenly when he hears. He",
+            "Someone drives; someone sits beside her. They",
+            "He opens the drawer in which we know someone "
+            "keeps his marijuana, but he does n't discover"
+            " its false bottom. He stands and looks around, his eyes",
+        ],
+        "sent1": [
+            "Someone leans out of the drive - thru "
+            "window, grinning at her, holding bags filled with fast food.",
+            "Someone looks up suddenly when he hears.",
+            "Someone drives; someone sits beside her.",
+            "He opens the drawer in which we know"
+            " someone keeps his marijuana, but he does n't discover its false bottom.",
+        ],
+        "sent2": [
+            "The Counter Girl",
+            "He",
+            "They",
+            "He stands and looks around, his eyes",
+        ],
+        "gold-source": ["gold", "gold", "gold", "gold"],
+        "ending0": [
+            "stands next to him, staring blankly.",
+            "puts his spatula down.",
+            "rise someone's feet up.",
+            "moving to the side, the houses rapidly stained.",
+        ],
+        "ending1": [
+            "with auditorium, filmed, singers the club.",
+            "bumps into a revolver and drops surreptitiously into his weapon.",
+            "lift her and they are alarmed.",
+            "focused as the sight of someone making his way down a trail.",
+        ],
+        "ending2": [
+            "attempts to block her ransacked.",
+            "talks using the phone and walks away for a few seconds.",
+            "are too involved with each other to "
+            "notice someone watching them from the drive - thru window.",
+            "finally landing on: the digicam and a stack of cassettes on a shelf.",
+        ],
+        "ending3": [
+            "is eating solid and stinky.",
+            "bundles the flaxen powder beneath the car.",
+            "sit at a table with a beer from a table.",
+            "deep and continuing, its bleed - length sideburns pressing on him.",
+        ],
+        "label": [0, 0, 2, 2],
+    }
 
     train_dataset = pd.DataFrame(train_data)
     dev_dataset = pd.DataFrame(dev_data)
 
test/nlp/test_autohf_tokenclassification.py (new file, 741 lines)
@@ -0,0 +1,741 @@
import sys
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.skipif(sys.platform == "darwin", reason="do not run on mac os")
 | 
			
		||||
def test_tokenclassification():
 | 
			
		||||
    from flaml import AutoML
 | 
			
		||||
    import pandas as pd
 | 
			
		||||
 | 
			
		||||
    train_data = {
 | 
			
		||||
        "chunk_tags": [
 | 
			
		||||
            [11, 21, 11, 12, 21, 22, 11, 12, 0],
 | 
			
		||||
            [11, 12],
 | 
			
		||||
            [11, 12],
 | 
			
		||||
            [
 | 
			
		||||
                11,
 | 
			
		||||
                12,
 | 
			
		||||
                12,
 | 
			
		||||
                21,
 | 
			
		||||
                13,
 | 
			
		||||
                11,
 | 
			
		||||
                11,
 | 
			
		||||
                21,
 | 
			
		||||
                13,
 | 
			
		||||
                11,
 | 
			
		||||
                12,
 | 
			
		||||
                13,
 | 
			
		||||
                11,
 | 
			
		||||
                21,
 | 
			
		||||
                22,
 | 
			
		||||
                11,
 | 
			
		||||
                12,
 | 
			
		||||
                17,
 | 
			
		||||
                11,
 | 
			
		||||
                21,
 | 
			
		||||
                17,
 | 
			
		||||
                11,
 | 
			
		||||
                12,
 | 
			
		||||
                12,
 | 
			
		||||
                21,
 | 
			
		||||
                22,
 | 
			
		||||
                22,
 | 
			
		||||
                13,
 | 
			
		||||
                11,
 | 
			
		||||
                0,
 | 
			
		||||
            ],
 | 
			
		||||
        ],
 | 
			
		||||
        "id": ["0", "1", "2", "3"],
 | 
			
		||||
        "ner_tags": [
 | 
			
		||||
            [3, 0, 7, 0, 0, 0, 7, 0, 0],
 | 
			
		||||
            [1, 2],
 | 
			
		||||
            [5, 0],
 | 
			
            [0, 3, 4, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ],
        "pos_tags": [
            [22, 42, 16, 21, 35, 37, 16, 21, 7],
            [22, 22],
            [22, 11],
            [12, 22, 22, 38, 15, 22, 28, 38, 15, 16, 21, 35, 24, 35, 37, 16, 21, 15, 24, 41, 15, 16, 21, 21, 20, 37, 40, 35, 21, 7],
        ],
        "tokens": [
            ["EU", "rejects", "German", "call", "to", "boycott", "British", "lamb", "."],
            ["Peter", "Blackburn"],
            ["BRUSSELS", "1996-08-22"],
            ["The", "European", "Commission", "said", "on", "Thursday", "it", "disagreed", "with", "German", "advice", "to", "consumers", "to", "shun", "British", "lamb", "until", "scientists", "determine", "whether", "mad", "cow", "disease", "can", "be", "transmitted", "to", "sheep", "."],
        ],
    }

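    # Validation split: four more sentences (ids "4"-"7") in the same
    # column format as train_data (chunk_tags / id / ner_tags / pos_tags / tokens).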
    dev_data = {
        "chunk_tags": [
            [11, 11, 12, 13, 11, 12, 12, 11, 12, 12, 12, 12, 21, 13, 11, 12, 21, 22, 11, 13, 11, 1, 13, 11, 17, 11, 12, 12, 21, 1, 0],
            [0, 11, 21, 22, 22, 11, 12, 12, 17, 11, 21, 22, 22, 11, 12, 13, 11, 0, 0, 11, 12, 11, 12, 12, 12, 12, 12, 12, 21, 11, 12, 12, 0],
            [11, 21, 11, 12, 12, 21, 22, 0, 17, 11, 21, 22, 17, 11, 21, 22, 11, 21, 22, 22, 13, 11, 12, 12, 0],
            [11, 21, 11, 12, 11, 12, 13, 11, 12, 12, 12, 12, 21, 22, 11, 12, 0, 11, 0, 11, 12, 13, 11, 12, 12, 12, 12, 12, 21, 11, 12, 1, 2, 2, 11, 21, 22, 11, 12, 0],
        ],
        "id": ["4", "5", "6", "7"],
        "ner_tags": [
            [5, 0, 0, 0, 0, 3, 4, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 1, 2, 2, 2, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0],
            [0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ],
        "pos_tags": [
            [22, 27, 21, 35, 12, 22, 22, 27, 16, 21, 22, 22, 38, 15, 22, 24, 20, 37, 21, 15, 24, 16, 15, 22, 15, 12, 16, 21, 38, 17, 7],
            [0, 28, 41, 30, 37, 12, 16, 21, 15, 28, 41, 30, 37, 12, 24, 15, 28, 6, 0, 12, 22, 27, 16, 21, 22, 22, 14, 22, 38, 12, 21, 21, 7],
            [28, 38, 16, 16, 21, 38, 40, 10, 15, 28, 38, 40, 15, 21, 38, 40, 28, 20, 37, 40, 15, 12, 22, 22, 7],
            [28, 38, 12, 21, 16, 21, 15, 22, 22, 22, 22, 22, 35, 37, 21, 24, 6, 24, 10, 16, 24, 15, 12, 21, 10, 21, 21, 24, 38, 12, 30, 16, 10, 16, 21, 35, 37, 16, 21, 7],
        ],
        "tokens": [
            ["Germany", "'s", "representative", "to", "the", "European", "Union", "'s", "veterinary", "committee", "Werner", "Zwingmann", "said", "on", "Wednesday", "consumers", "should", "buy", "sheepmeat", "from", "countries", "other", "than", "Britain", "until", "the", "scientific", "advice", "was", "clearer", "."],
            ['"', "We", "do", "n't", "support", "any", "such", "recommendation", "because", "we", "do", "n't", "see", "any", "grounds", "for", "it", ",", '"', "the", "Commission", "'s", "chief", "spokesman", "Nikolaus", "van", "der", "Pas", "told", "a", "news", "briefing", "."],
            ["He", "said", "further", "scientific", "study", "was", "required", "and", "if", "it", "was", "found", "that", "action", "was", "needed", "it", "should", "be", "taken", "by", "the", "European", "Union", "."],
            ["He", "said", "a", "proposal", "last", "month", "by", "EU", "Farm", "Commissioner", "Franz", "Fischler", "to", "ban", "sheep", "brains", ",", "spleens", "and", "spinal", "cords", "from", "the", "human", "and", "animal", "food", "chains", "was", "a", "highly", "specific", "and", "precautionary", "move", "to", "protect", "human", "health", "."],
        ],
    }

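    # Only the "tokens" column is fed to AutoML as X, with "ner_tags" as the
    # labels y; the chunk_tags/pos_tags columns are carried along but unused here.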
    train_dataset = pd.DataFrame(train_data)
    dev_dataset = pd.DataFrame(dev_data)

    custom_sent_keys = ["tokens"]
    label_key = "ner_tags"

    X_train = train_dataset[custom_sent_keys]
    y_train = train_dataset[label_key]

    X_val = dev_dataset[custom_sent_keys]
    y_val = dev_dataset[label_key]

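    # A deliberately small search (2 trials within a 5-second budget) keeps this
    # test fast; gpu_per_trial=0 runs on CPU, and "seqeval" is the
    # sequence-labeling metric used for token classification.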
    automl = AutoML()

    automl_settings = {
        "gpu_per_trial": 0,
        "max_iter": 2,
        "time_budget": 5,
        "task": "token-classification",
        "metric": "seqeval",
    }

    automl_settings["custom_hpo_args"] = {
        "model_path": "bert-base-uncased",
        "output_dir": "test/data/output/",
        "ckpt_per_epoch": 5,
        "fp16": False,
    }

    automl.fit(
        X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, **automl_settings
    )


if __name__ == "__main__":
    test_tokenclassification()