diff --git a/haystack/modeling/model/language_model.py b/haystack/modeling/model/language_model.py
index a90202dd4..7286172b8 100644
--- a/haystack/modeling/model/language_model.py
+++ b/haystack/modeling/model/language_model.py
@@ -48,6 +48,8 @@ from transformers import (
     CamembertConfig,
     BigBirdModel,
     BigBirdConfig,
+    DebertaV2Model,
+    DebertaV2Config,
 )
 from transformers import AutoModel, AutoConfig
 from transformers.modeling_utils import SequenceSummary
@@ -238,6 +240,8 @@ class LanguageModel(nn.Module):
                 raise NotImplementedError("DPRReader models are currently not supported.")
         elif model_type == "big_bird":
             language_model_class = "BigBird"
+        elif model_type == "deberta-v2":
+            language_model_class = "DebertaV2"
         else:
             # Fall back to inferring type from model name
             logger.warning(
@@ -1570,3 +1574,101 @@ class BigBird(LanguageModel):
 
     def disable_hidden_states_output(self):
         self.model.encoder.config.output_hidden_states = False
+
+
+class DebertaV2(LanguageModel):
+    """
+    This is a wrapper around the DebertaV2 model from HuggingFace's transformers library.
+    It is also compatible with DebertaV3 as DebertaV3 only changes the pretraining procedure.
+
+    NOTE:
+    - DebertaV2 does not output the pooled_output. An additional pooler is initialized.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.model = None
+        self.name = "deberta-v2"
+        self.pooler = None
+
+    @classmethod
+    @silence_transformers_logs
+    def load(cls, pretrained_model_name_or_path: Union[Path, str], language: str = None, **kwargs):
+        """
+        Load a pretrained model by supplying
+
+        * a remote name from Huggingface's modelhub ("microsoft/deberta-v3-base" ...)
+        * OR a local path of a model trained via transformers ("some_dir/huggingface_model")
+        * OR a local path of a model trained via Haystack ("some_dir/haystack_model")
+
+        :param pretrained_model_name_or_path: The path of the saved pretrained model or its name.
+        """
+        debertav2 = cls()
+        if "haystack_lm_name" in kwargs:
+            debertav2.name = kwargs["haystack_lm_name"]
+        else:
+            debertav2.name = pretrained_model_name_or_path
+        # We need to differentiate between loading model using Haystack format and Transformers format
+        haystack_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json"
+        if os.path.exists(haystack_lm_config):
+            # Haystack style
+            config = DebertaV2Config.from_pretrained(haystack_lm_config)
+            haystack_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin"
+            debertav2.model = DebertaV2Model.from_pretrained(haystack_lm_model, config=config, **kwargs)
+            debertav2.language = debertav2.model.config.language
+        else:
+            # Transformers Style
+            debertav2.model = DebertaV2Model.from_pretrained(str(pretrained_model_name_or_path), **kwargs)
+            debertav2.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)
+        config = debertav2.model.config
+
+        # DebertaV2 does not provide a pooled_output by default. Therefore, we need to initialize an extra pooler.
+        # The pooler takes the first hidden representation & feeds it to a dense layer of (hidden_dim x hidden_dim).
+        # We don't want a dropout in the end of the pooler, since we do that already in the adaptive model before we
+        # feed everything to the prediction head.
+        config.summary_last_dropout = 0
+        config.summary_type = "first"
+        config.summary_activation = "tanh"
+        config.summary_use_proj = False
+        debertav2.pooler = SequenceSummary(config)
+        debertav2.pooler.apply(debertav2.model._init_weights)
+        return debertav2
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        segment_ids: torch.Tensor,
+        padding_mask: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        **kwargs,
+    ):
+        """
+        Perform the forward pass of the DebertaV2 model.
+
+        :param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
+        :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
+           of shape [batch_size, max_seq_len]
+        :param output_hidden_states: Whether to output hidden states in addition to the embeddings
+        :param output_attentions: Whether to output attentions in addition to the embeddings
+        :return: Embeddings for each token in the input sequence and a pooled representation of each sequence.
+        """
+        if output_hidden_states is None:
+            output_hidden_states = self.model.config.output_hidden_states
+        if output_attentions is None:
+            output_attentions = self.model.config.output_attentions
+
+        output_tuple = self.model(
+            input_ids,
+            token_type_ids=segment_ids,
+            attention_mask=padding_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            return_dict=False,
+        )
+        # We need to manually aggregate that to get a pooled output (one vec per seq)
+        pooled_output = self.pooler(output_tuple[0])
+        return (output_tuple[0], pooled_output) + output_tuple[1:]
+
+    def disable_hidden_states_output(self):
+        self.model.config.output_hidden_states = False
diff --git a/haystack/modeling/model/tokenization.py b/haystack/modeling/model/tokenization.py
index 0f2e8710f..db1f2a1bc 100644
--- a/haystack/modeling/model/tokenization.py
+++ b/haystack/modeling/model/tokenization.py
@@ -44,6 +44,8 @@ from transformers import (
     DPRQuestionEncoderTokenizerFast,
     BigBirdTokenizer,
     BigBirdTokenizerFast,
+    DebertaV2Tokenizer,
+    DebertaV2TokenizerFast,
 )
 from transformers import AutoConfig
 
@@ -197,6 +199,15 @@ class Tokenizer:
                 ret = BigBirdTokenizer.from_pretrained(
                     pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs
                 )
+        elif "DebertaV2Tokenizer" in tokenizer_class:
+            if use_fast:
+                ret = DebertaV2TokenizerFast.from_pretrained(
+                    pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs
+                )
+            else:
+                ret = DebertaV2Tokenizer.from_pretrained(
+                    pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs
+                )
         if ret is None:
             raise Exception("Unable to load tokenizer")
         return ret
@@ -246,6 +257,8 @@ class Tokenizer:
                 raise NotImplementedError("DPRReader models are currently not supported.")
         elif model_type == "big_bird":
             tokenizer_class = "BigBirdTokenizer"
+        elif model_type == "deberta-v2":
+            tokenizer_class = "DebertaV2Tokenizer"
         else:
             # Fall back to inferring type from model name
             logger.warning(
@@ -275,6 +288,10 @@ class Tokenizer:
             tokenizer_class = "CamembertTokenizer"
         elif "distilbert" in pretrained_model_name_or_path.lower():
             tokenizer_class = "DistilBertTokenizer"
+        elif (
+            "deberta-v2" in pretrained_model_name_or_path.lower() or "deberta-v3" in pretrained_model_name_or_path.lower()
+        ):
+            tokenizer_class = "DebertaV2Tokenizer"
         elif "bert" in pretrained_model_name_or_path.lower():
             tokenizer_class = "BertTokenizer"
         elif "xlnet" in pretrained_model_name_or_path.lower():
diff --git a/haystack/modeling/training/base.py b/haystack/modeling/training/base.py
index e029e9732..67a126c2f 100644
--- a/haystack/modeling/training/base.py
+++ b/haystack/modeling/training/base.py
@@ -18,6 +18,7 @@ from haystack.modeling.data_handler.data_silo import DataSilo, DistillationDataS
 from haystack.modeling.evaluation.eval import Evaluator
 from haystack.modeling.model.adaptive_model import AdaptiveModel
 from haystack.modeling.model.optimization import get_scheduler
+from haystack.modeling.model.language_model import DebertaV2
 from haystack.modeling.utils import GracefulKiller
 from haystack.utils.experiment_tracking import Tracker as tracker
 
@@ -250,7 +251,9 @@ class Trainer:
                 vocab_size1=len(self.data_silo.processor.query_tokenizer),
                 vocab_size2=len(self.data_silo.processor.passage_tokenizer),
             )
-        else:
+        elif not isinstance(
+            self.model.language_model, DebertaV2
+        ):  # DebertaV2 has mismatched vocab size on purpose (see https://github.com/huggingface/transformers/issues/12428)
             self.model.verify_vocab_size(vocab_size=len(self.data_silo.processor.tokenizer))
         self.model.train()
diff --git a/setup.cfg b/setup.cfg
index 9d8d1b96a..9afddfff1 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -56,7 +56,7 @@ install_requires =
     torch>1.9,<1.12
     requests
     pydantic
-    transformers==4.18.0
+    transformers==4.19.2
     nltk
     pandas
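
For context (not part of the patch itself), below is a minimal usage sketch of the wiring added above. It assumes the existing `Tokenizer.load` and `LanguageModel.load` entry points that the diff hooks into, and uses "microsoft/deberta-v3-base" purely as an example checkpoint name:

```python
# Minimal sketch, assuming the loaders shown in the diff and an example DeBERTa-v3 checkpoint.
from haystack.modeling.model.language_model import LanguageModel
from haystack.modeling.model.tokenization import Tokenizer

model_name = "microsoft/deberta-v3-base"  # example checkpoint, not prescribed by the patch

# Tokenizer.load resolves model_type "deberta-v2" from the model config and picks
# DebertaV2Tokenizer (or the fast variant) via the new branch in tokenization.py.
tokenizer = Tokenizer.load(pretrained_model_name_or_path=model_name)

# LanguageModel.load maps model_type "deberta-v2" to the new DebertaV2 wrapper, which attaches a
# SequenceSummary pooler because the underlying HF model does not return a pooled output.
language_model = LanguageModel.load(pretrained_model_name_or_path=model_name)
```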