Mirror of https://github.com/microsoft/autogen.git, synced 2025-10-30 17:29:47 +00:00
69 lines · 2.7 KiB · Python
from collections import OrderedDict

import transformers

# transformers v3 exposes model modules at the top level (transformers.modeling_*);
# v4 moved them under transformers.models.*, hence the version check below.
if transformers.__version__.startswith("3"):
    from transformers.modeling_electra import ElectraClassificationHead
    from transformers.modeling_roberta import RobertaClassificationHead
    from transformers.modeling_electra import ElectraForTokenClassification
    from transformers.modeling_roberta import RobertaForTokenClassification

else:
    from transformers.models.electra.modeling_electra import ElectraClassificationHead
    from transformers.models.roberta.modeling_roberta import RobertaClassificationHead
    from transformers.models.electra.modeling_electra import ElectraForTokenClassification
    from transformers.models.roberta.modeling_roberta import RobertaForTokenClassification

# Maps a model type name to the corresponding sequence-classification head class.
MODEL_CLASSIFICATION_HEAD_MAPPING = OrderedDict(
    [
        ("electra", ElectraClassificationHead),
        ("roberta", RobertaClassificationHead),
    ]
)


class AutoSeqClassificationHead:
    """
    This is a class for getting the classification head class based on the name of the LM,
    instantiated as one of the ClassificationHead classes of the library when
    created with the `AutoSeqClassificationHead.from_model_type_and_config` method.

    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoSeqClassificationHead is designed to be instantiated "
            "using the `AutoSeqClassificationHead.from_model_type_and_config(cls, model_type, config)` method."
        )

    @classmethod
    def from_model_type_and_config(
        cls, model_type: str, config: transformers.PretrainedConfig
    ):
        """
        Instantiate one of the classification head classes from the model_type and model configuration.

        Args:
            model_type: A string, which describes the model type, e.g., "electra".
            config: The huggingface class of the model's configuration.

        Example:

        ```python
        from transformers import AutoConfig
        model_config = AutoConfig.from_pretrained("google/electra-base-discriminator")
        AutoSeqClassificationHead.from_model_type_and_config("electra", model_config)
        ```
        """
        if model_type in MODEL_CLASSIFICATION_HEAD_MAPPING.keys():
            return MODEL_CLASSIFICATION_HEAD_MAPPING[model_type](config)
        raise ValueError(
            "Unrecognized configuration class {} for class {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(MODEL_CLASSIFICATION_HEAD_MAPPING.keys()),
            )
        )

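# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): build a head
# for an ELECTRA config and swap it into a sequence-classification model.
# Assumes transformers v4+ and that the example checkpoint can be downloaded;
# the checkpoint name and num_labels value below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoConfig, ElectraForSequenceClassification

    model_name = "google/electra-base-discriminator"
    model_config = AutoConfig.from_pretrained(model_name, num_labels=3)
    model = ElectraForSequenceClassification.from_pretrained(model_name, config=model_config)

    # `classifier` is the attribute ElectraForSequenceClassification uses for its
    # classification head, so a freshly built head can be dropped in directly.
    model.classifier = AutoSeqClassificationHead.from_model_type_and_config(
        "electra", model_config
    )
    print(type(model.classifier).__name__)  # -> ElectraClassificationHead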