mirror of https://github.com/deepset-ai/haystack.git
synced 2025-09-14 18:52:11 +00:00
Add DeBERTaV2/V3 support (#2097)

* add debertav2/v3
* update comments
* Apply Black
* assume support for fast deberta tokenizer
* Apply Black
* update required transformers version for deberta
* fix mismatched vocab error
* Update Documentation & Code Style
* update debertav2 doc string

Co-authored-by: Sara Zan <sara.zanzottera@deepset.ai>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent a9a4156731
commit 16b0fdd804

@@ -48,6 +48,8 @@ from transformers import (
     CamembertConfig,
     BigBirdModel,
     BigBirdConfig,
+    DebertaV2Model,
+    DebertaV2Config,
 )
 from transformers import AutoModel, AutoConfig
 from transformers.modeling_utils import SequenceSummary

@@ -238,6 +240,8 @@ class LanguageModel(nn.Module):
             raise NotImplementedError("DPRReader models are currently not supported.")
         elif model_type == "big_bird":
             language_model_class = "BigBird"
+        elif model_type == "deberta-v2":
+            language_model_class = "DebertaV2"
         else:
             # Fall back to inferring type from model name
             logger.warning(

@@ -1570,3 +1574,101 @@ class BigBird(LanguageModel):
 
     def disable_hidden_states_output(self):
         self.model.encoder.config.output_hidden_states = False
+
+
+class DebertaV2(LanguageModel):
+    """
+    This is a wrapper around the DebertaV2 model from HuggingFace's transformers library.
+    It is also compatible with DebertaV3 as DebertaV3 only changes the pretraining procedure.
+
+    NOTE:
+    - DebertaV2 does not output the pooled_output. An additional pooler is initialized.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.model = None
+        self.name = "deberta-v2"
+        self.pooler = None
+
+    @classmethod
+    @silence_transformers_logs
+    def load(cls, pretrained_model_name_or_path: Union[Path, str], language: str = None, **kwargs):
+        """
+        Load a pretrained model by supplying
+
+        * a remote name from Huggingface's modelhub ("microsoft/deberta-v3-base" ...)
+        * OR a local path of a model trained via transformers ("some_dir/huggingface_model")
+        * OR a local path of a model trained via Haystack ("some_dir/haystack_model")
+
+        :param pretrained_model_name_or_path: The path of the saved pretrained model or its name.
+        """
+        debertav2 = cls()
+        if "haystack_lm_name" in kwargs:
+            debertav2.name = kwargs["haystack_lm_name"]
+        else:
+            debertav2.name = pretrained_model_name_or_path
+        # We need to differentiate between loading model using Haystack format and Transformers format
+        haystack_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json"
+        if os.path.exists(haystack_lm_config):
+            # Haystack style
+            config = DebertaV2Config.from_pretrained(haystack_lm_config)
+            haystack_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin"
+            debertav2.model = DebertaV2Model.from_pretrained(haystack_lm_model, config=config, **kwargs)
+            debertav2.language = debertav2.model.config.language
+        else:
+            # Transformers Style
+            debertav2.model = DebertaV2Model.from_pretrained(str(pretrained_model_name_or_path), **kwargs)
+            debertav2.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)
+        config = debertav2.model.config
+
+        # DebertaV2 does not provide a pooled_output by default. Therefore, we need to initialize an extra pooler.
+        # The pooler takes the first hidden representation & feeds it to a dense layer of (hidden_dim x hidden_dim).
+        # We don't want a dropout in the end of the pooler, since we do that already in the adaptive model before we
+        # feed everything to the prediction head.
+        config.summary_last_dropout = 0
+        config.summary_type = "first"
+        config.summary_activation = "tanh"
+        config.summary_use_proj = False
+        debertav2.pooler = SequenceSummary(config)
+        debertav2.pooler.apply(debertav2.model._init_weights)
+        return debertav2
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        segment_ids: torch.Tensor,
+        padding_mask: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        **kwargs,
+    ):
+        """
+        Perform the forward pass of the DebertaV2 model.
+
+        :param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
+        :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
+           of shape [batch_size, max_seq_len]
+        :param output_hidden_states: Whether to output hidden states in addition to the embeddings
+        :param output_attentions: Whether to output attentions in addition to the embeddings
+        :return: Embeddings for each token in the input sequence.
+        """
+        output_tuple = self.model(input_ids, token_type_ids=segment_ids, attention_mask=padding_mask, return_dict=False)
+
+        if output_hidden_states is None:
+            output_hidden_states = self.model.encoder.config.output_hidden_states
+        if output_attentions is None:
+            output_attentions = self.model.encoder.config.output_attentions
+
+        output_tuple = self.model(
+            input_ids,
+            attention_mask=padding_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+        )
+        # We need to manually aggregate that to get a pooled output (one vec per seq)
+        pooled_output = self.pooler(output_tuple[0])
+        return (output_tuple[0], pooled_output) + output_tuple[1:]
+
+    def disable_hidden_states_output(self):
+        self.model.config.output_hidden_states = False

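For orientation, here is a minimal usage sketch of the wrapper added above. It is not part of this commit; the checkpoint name, the explicit output flags, and the tensor handling are illustrative assumptions.

# Illustrative only -- exercises the new DebertaV2 wrapper via LanguageModel.load().
import torch
from transformers import AutoTokenizer

from haystack.modeling.model.language_model import LanguageModel

model_name = "microsoft/deberta-v3-base"  # example checkpoint; its config.model_type is "deberta-v2"
model = LanguageModel.load(model_name)    # resolves to the DebertaV2 wrapper via the mapping above

tokenizer = AutoTokenizer.from_pretrained(model_name)
batch = tokenizer(["DeBERTaV3 support in Haystack."], return_tensors="pt", padding=True)

with torch.no_grad():
    outputs = model(
        input_ids=batch["input_ids"],
        segment_ids=batch["token_type_ids"],
        padding_mask=batch["attention_mask"],
        output_hidden_states=False,  # passed explicitly so the wrapper does not fall back to model defaults
        output_attentions=False,
    )

sequence_output, pooled_output = outputs[0], outputs[1]
print(sequence_output.shape)  # [batch_size, seq_len, hidden_dim]
print(pooled_output.shape)    # [batch_size, hidden_dim], produced by the extra SequenceSummary pooler
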
@@ -44,6 +44,8 @@ from transformers import (
     DPRQuestionEncoderTokenizerFast,
     BigBirdTokenizer,
     BigBirdTokenizerFast,
+    DebertaV2Tokenizer,
+    DebertaV2TokenizerFast,
 )
 from transformers import AutoConfig
 
@@ -197,6 +199,15 @@ class Tokenizer:
             ret = BigBirdTokenizer.from_pretrained(
                 pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs
             )
+        elif "DebertaV2Tokenizer" in tokenizer_class:
+            if use_fast:
+                ret = DebertaV2TokenizerFast.from_pretrained(
+                    pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs
+                )
+            else:
+                ret = DebertaV2Tokenizer.from_pretrained(
+                    pretrained_model_name_or_path, use_auth_token=use_auth_token, **kwargs
+                )
         if ret is None:
             raise Exception("Unable to load tokenizer")
         return ret

@@ -246,6 +257,8 @@ class Tokenizer:
             raise NotImplementedError("DPRReader models are currently not supported.")
         elif model_type == "big_bird":
             tokenizer_class = "BigBirdTokenizer"
+        elif model_type == "deberta-v2":
+            tokenizer_class = "DebertaV2Tokenizer"
         else:
             # Fall back to inferring type from model name
             logger.warning(

@@ -275,6 +288,10 @@ class Tokenizer:
             tokenizer_class = "CamembertTokenizer"
         elif "distilbert" in pretrained_model_name_or_path.lower():
             tokenizer_class = "DistilBertTokenizer"
+        elif (
+            "debertav2" in pretrained_model_name_or_path.lower() or "debertav3" in pretrained_model_name_or_path.lower()
+        ):
+            tokenizer_class = "DebertaV2Tokenizer"
         elif "bert" in pretrained_model_name_or_path.lower():
             tokenizer_class = "BertTokenizer"
         elif "xlnet" in pretrained_model_name_or_path.lower():

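As a quick check of the tokenizer resolution added above, a minimal sketch follows. It is not part of this diff; the checkpoint name is an example, and the module path is assumed to be Haystack's modeling package where the Tokenizer class shown above lives.

# Illustrative only -- DeBERTaV2/V3 checkpoints report model_type "deberta-v2",
# so tokenizer_class resolves to "DebertaV2Tokenizer"; with use_fast=True the
# fast variant is returned, otherwise the slow SentencePiece-based one.
from haystack.modeling.model.tokenization import Tokenizer

fast_tokenizer = Tokenizer.load("microsoft/deberta-v3-base", use_fast=True)
slow_tokenizer = Tokenizer.load("microsoft/deberta-v3-base", use_fast=False)

print(type(fast_tokenizer).__name__)  # expected: DebertaV2TokenizerFast
print(type(slow_tokenizer).__name__)  # expected: DebertaV2Tokenizer
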
@@ -18,6 +18,7 @@ from haystack.modeling.data_handler.data_silo import DataSilo, DistillationDataS
 from haystack.modeling.evaluation.eval import Evaluator
 from haystack.modeling.model.adaptive_model import AdaptiveModel
 from haystack.modeling.model.optimization import get_scheduler
+from haystack.modeling.model.language_model import DebertaV2
 from haystack.modeling.utils import GracefulKiller
 from haystack.utils.experiment_tracking import Tracker as tracker
 
@@ -250,7 +251,9 @@ class Trainer:
                 vocab_size1=len(self.data_silo.processor.query_tokenizer),
                 vocab_size2=len(self.data_silo.processor.passage_tokenizer),
             )
-        else:
+        elif not isinstance(
+            self.model.language_model, DebertaV2
+        ):  # DebertaV2 has mismatched vocab size on purpose (see https://github.com/huggingface/transformers/issues/12428)
             self.model.verify_vocab_size(vocab_size=len(self.data_silo.processor.tokenizer))
         self.model.train()
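
The guard above is needed because DeBERTaV2/V3 checkpoints ship a config vocab_size that deliberately differs from the number of tokens in the tokenizer, so Haystack's vocab check would raise a false alarm. Below is a minimal sketch of how to observe this; it is not part of the commit, and the exact numbers depend on the checkpoint (see the linked transformers issue).

# Illustrative only -- shows the intentional mismatch that motivates skipping
# verify_vocab_size() for DebertaV2 models.
from transformers import AutoConfig, AutoTokenizer

name = "microsoft/deberta-v3-base"  # example checkpoint
config = AutoConfig.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

# For this model family, config.vocab_size and len(tokenizer) are not expected to match.
print(config.vocab_size, len(tokenizer))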