autogen/flaml/nlp/huggingface/switch_head_auto.py
Xueqing Liu 4814091d87
remove redundant imports (#426)
* remove redundant imports

* getting rid of hf dataset
2022-01-24 14:24:14 -08:00

65 lines
2.3 KiB
Python

from collections import OrderedDict
import transformers
if transformers.__version__.startswith("3"):
from transformers.modeling_electra import ElectraClassificationHead
from transformers.modeling_roberta import RobertaClassificationHead
else:
from transformers.models.electra.modeling_electra import ElectraClassificationHead
from transformers.models.roberta.modeling_roberta import RobertaClassificationHead
# Registry mapping a lowercase model-type string (e.g. "electra") to the
# corresponding sequence-classification head class from transformers.
# Keyword-argument construction preserves insertion order (Python 3.7+),
# so iteration order matches the original ("electra" first, then "roberta").
MODEL_CLASSIFICATION_HEAD_MAPPING = OrderedDict(
    electra=ElectraClassificationHead,
    roberta=RobertaClassificationHead,
)
class AutoSeqClassificationHead:
    """Factory for per-model sequence-classification head classes.

    Resolves a classification head class from the name of the language model,
    instantiated as one of the ClassificationHead classes of the library when
    created with the `AutoSeqClassificationHead.from_model_type_and_config` method.

    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """

    def __init__(self):
        # Deliberate guard: this class is a factory namespace only.
        raise EnvironmentError(
            "AutoSeqClassificationHead is designed to be instantiated "
            "using the `AutoSeqClassificationHead.from_model_type_and_config(cls, model_type, config)` methods."
        )

    @classmethod
    def from_model_type_and_config(
        cls, model_type: str, config: "transformers.PretrainedConfig"
    ):
        """Instantiate one of the classification head classes from the model_type and model configuration.

        Args:
            model_type: A string, which describes the model type, e.g., "electra".
            config: The huggingface class of the model's configuration.

        Returns:
            An instance of the classification head class registered for
            ``model_type``, constructed with ``config``.

        Raises:
            ValueError: If ``model_type`` is not present in
                ``MODEL_CLASSIFICATION_HEAD_MAPPING``.

        Example:

        ```python
        from transformers import AutoConfig
        model_config = AutoConfig.from_pretrained("google/electra-base-discriminator")
        AutoSeqClassificationHead.from_model_type_and_config("electra", model_config)
        ```
        """
        # Membership test on the mapping itself — `.keys()` is redundant.
        if model_type in MODEL_CLASSIFICATION_HEAD_MAPPING:
            return MODEL_CLASSIFICATION_HEAD_MAPPING[model_type](config)
        raise ValueError(
            "Unrecognized configuration class {} for class {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(MODEL_CLASSIFICATION_HEAD_MAPPING.keys()),
            )
        )