Fix finetuning notebook augmentation (#2071)

* fix data augmentation path in finetuning notebook

* Add latest docstring and tutorial changes

* make distillation possible with models other than BERT

* use smaller dataset for distillation in finetuning tutorial

* Add latest docstring and tutorial changes

* make data augmentation in finetuning faster

* update language models forward doc strings

* fix return type of language models

* remove debug output

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
MichelBartels 2022-01-26 17:49:14 +01:00 committed by GitHub
parent c4fff19018
commit 4cc37548e3
4 changed files with 129 additions and 57 deletions

View File

@@ -104,8 +104,16 @@ To get the most out of model distillation, we recommend increasing the size of y
 ```python
 # Downloading script
 !wget https://raw.githubusercontent.com/deepset-ai/haystack/master/haystack/utils/augment_squad.py
-# Just replace the path with your dataset and adjust the output
-!python augment_squad.py --squad_path data/squad20/dev-v2.0.json --output_path augmented_dataset.json --multiplication_factor 2
+
+# Downloading smaller glove vector file (only for demonstration purposes)
+!wget https://nlp.stanford.edu/data/glove.6B.zip
+!unzip glove.6B.zip
+
+# Downloading very small dataset to make tutorial faster (please use a bigger dataset for real use cases)
+!wget https://raw.githubusercontent.com/deepset-ai/haystack/master/test/samples/squad/small.json
+
+# Just replace the path with your dataset and adjust the output (also please remove glove path to use bigger glove vector file)
+!python augment_squad.py --squad_path small.json --output_path augmented_dataset.json --multiplication_factor 2 --glove_path glove.6B.300d.txt
 ```
 
 In this case, we use a multiplication factor of 2 to keep this example lightweight. Usually you would use a factor like 20 depending on the size of your training data. Augmenting this small dataset with a multiplication factor of 2, should take about 5 to 10 minutes to run on one V100 GPU.
@@ -124,7 +132,7 @@ teacher = FARMReader(model_name_or_path="my_model", use_gpu=True)
 # The number of the layers in the teacher model also needs to be a multiple of the number of the layers in the student.
 student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D", use_gpu=True)
 
-student.distil_intermediate_layers_from(teacher, data_dir="data/squad20", train_filename="augmented_dataset.json", use_gpu=True)
+student.distil_intermediate_layers_from(teacher, data_dir=".", train_filename="augmented_dataset.json", use_gpu=True)
 student.distil_prediction_layer_from(teacher, data_dir="data/squad20", train_filename="dev-v2.0.json", use_gpu=True)
 
 student.save(directory="my_distilled_model")
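
For reference, a real (non-demo) run of the same augmentation step would only swap the inputs. A minimal sketch, where `my_dataset.json` is a hypothetical placeholder for your own SQuAD-format file, following the tutorial's own advice to drop the small GloVe file and use a larger multiplication factor:

```python
# Hypothetical real-use variant of the cell above: `my_dataset.json` stands in for your own
# SQuAD-format training file. Per the tutorial comment, dropping --glove_path lets the script
# use its bigger default GloVe vectors, and a multiplication factor around 20 is more typical.
!wget https://raw.githubusercontent.com/deepset-ai/haystack/master/haystack/utils/augment_squad.py
!python augment_squad.py --squad_path my_dataset.json --output_path augmented_dataset.json --multiplication_factor 20
```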

View File

@@ -572,6 +572,8 @@ class Albert(LanguageModel):
         input_ids: torch.Tensor,
         segment_ids: torch.Tensor,
         padding_mask: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
         **kwargs,
     ):
         """
@@ -583,19 +585,24 @@ class Albert(LanguageModel):
                            It is a tensor of shape [batch_size, max_seq_len]
         :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
                              of shape [batch_size, max_seq_len]
+        :param output_hidden_states: Whether to output hidden states in addition to the embeddings
+        :param output_attentions: Whether to output attentions in addition to the embeddings
         :return: Embeddings for each token in the input sequence.
         """
+        if output_hidden_states is None:
+            output_hidden_states = self.model.encoder.config.output_hidden_states
+        if output_attentions is None:
+            output_attentions = self.model.encoder.config.output_attentions
         output_tuple = self.model(
             input_ids,
             token_type_ids=segment_ids,
             attention_mask=padding_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            return_dict=False
         )
-        if self.model.encoder.config.output_hidden_states == True:
-            sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2]
-            return sequence_output, pooled_output, all_hidden_states
-        else:
-            sequence_output, pooled_output = output_tuple[0], output_tuple[1]
-            return sequence_output, pooled_output
+        return output_tuple
 
     def enable_hidden_states_output(self):
         self.model.encoder.config.output_hidden_states = True
@@ -654,6 +661,8 @@ class Roberta(LanguageModel):
         input_ids: torch.Tensor,
         segment_ids: torch.Tensor,
         padding_mask: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
         **kwargs,
     ):
         """
@@ -665,19 +674,24 @@ class Roberta(LanguageModel):
                            It is a tensor of shape [batch_size, max_seq_len]
         :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
                              of shape [batch_size, max_seq_len]
+        :param output_hidden_states: Whether to output hidden states in addition to the embeddings
+        :param output_attentions: Whether to output attentions in addition to the embeddings
         :return: Embeddings for each token in the input sequence.
         """
+        if output_hidden_states is None:
+            output_hidden_states = self.model.encoder.config.output_hidden_states
+        if output_attentions is None:
+            output_attentions = self.model.encoder.config.output_attentions
         output_tuple = self.model(
             input_ids,
             token_type_ids=segment_ids,
             attention_mask=padding_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            return_dict=False
         )
-        if self.model.encoder.config.output_hidden_states == True:
-            sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2]
-            return sequence_output, pooled_output, all_hidden_states
-        else:
-            sequence_output, pooled_output = output_tuple[0], output_tuple[1]
-            return sequence_output, pooled_output
+        return output_tuple
 
     def enable_hidden_states_output(self):
         self.model.encoder.config.output_hidden_states = True
@@ -736,6 +750,8 @@ class XLMRoberta(LanguageModel):
         input_ids: torch.Tensor,
         segment_ids: torch.Tensor,
         padding_mask: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
         **kwargs,
     ):
         """
@@ -747,19 +763,24 @@ class XLMRoberta(LanguageModel):
                            It is a tensor of shape [batch_size, max_seq_len]
         :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
                              of shape [batch_size, max_seq_len]
+        :param output_hidden_states: Whether to output hidden states in addition to the embeddings
+        :param output_attentions: Whether to output attentions in addition to the embeddings
         :return: Embeddings for each token in the input sequence.
         """
+        if output_hidden_states is None:
+            output_hidden_states = self.model.encoder.config.output_hidden_states
+        if output_attentions is None:
+            output_attentions = self.model.encoder.config.output_attentions
         output_tuple = self.model(
             input_ids,
             token_type_ids=segment_ids,
             attention_mask=padding_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            return_dict=False
         )
-        if self.model.encoder.config.output_hidden_states == True:
-            sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2]
-            return sequence_output, pooled_output, all_hidden_states
-        else:
-            sequence_output, pooled_output = output_tuple[0], output_tuple[1]
-            return sequence_output, pooled_output
+        return output_tuple
 
     def enable_hidden_states_output(self):
         self.model.encoder.config.output_hidden_states = True
@@ -832,6 +853,8 @@ class DistilBert(LanguageModel):
         self,
         input_ids: torch.Tensor,
         padding_mask: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
         **kwargs,
     ):
         """
@@ -840,20 +863,25 @@ class DistilBert(LanguageModel):
         :param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
         :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
                              of shape [batch_size, max_seq_len]
+        :param output_hidden_states: Whether to output hidden states in addition to the embeddings
+        :param output_attentions: Whether to output attentions in addition to the embeddings
         :return: Embeddings for each token in the input sequence.
         """
+        if output_hidden_states is None:
+            output_hidden_states = self.model.encoder.config.output_hidden_states
+        if output_attentions is None:
+            output_attentions = self.model.encoder.config.output_attentions
         output_tuple = self.model(
             input_ids,
             attention_mask=padding_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            return_dict=False
         )
         # We need to manually aggregate that to get a pooled output (one vec per seq)
         pooled_output = self.pooler(output_tuple[0])
-        if self.model.config.output_hidden_states == True:
-            sequence_output, all_hidden_states = output_tuple[0], output_tuple[1]
-            return sequence_output, pooled_output
-        else:
-            sequence_output = output_tuple[0]
-            return sequence_output, pooled_output
+        return (output_tuple[0], pooled_output) + output_tuple[1:]
 
     def enable_hidden_states_output(self):
         self.model.config.output_hidden_states = True
@@ -921,6 +949,8 @@ class XLNet(LanguageModel):
         input_ids: torch.Tensor,
         segment_ids: torch.Tensor,
         padding_mask: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
         **kwargs,
     ):
         """
@@ -932,26 +962,29 @@ class XLNet(LanguageModel):
                            It is a tensor of shape [batch_size, max_seq_len]
         :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
                              of shape [batch_size, max_seq_len]
+        :param output_hidden_states: Whether to output hidden states in addition to the embeddings
+        :param output_attentions: Whether to output attentions in addition to the embeddings
         :return: Embeddings for each token in the input sequence.
         """
+        if output_hidden_states is None:
+            output_hidden_states = self.model.encoder.config.output_hidden_states
+        if output_attentions is None:
+            output_attentions = self.model.encoder.config.output_attentions
         # Note: XLNet has a couple of special input tensors for pretraining / text generation (perm_mask, target_mapping ...)
         # We will need to implement them, if we wanna support LM adaptation
         output_tuple = self.model(
             input_ids,
-            token_type_ids=segment_ids,
             attention_mask=padding_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            return_dict=False
         )
         # XLNet also only returns the sequence_output (one vec per token)
         # We need to manually aggregate that to get a pooled output (one vec per seq)
         # TODO verify that this is really doing correct pooling
         pooled_output = self.pooler(output_tuple[0])
-        if self.model.output_hidden_states == True:
-            sequence_output, all_hidden_states = output_tuple[0], output_tuple[1]
-            return sequence_output, pooled_output, all_hidden_states
-        else:
-            sequence_output = output_tuple[0]
-            return sequence_output, pooled_output
+        return (output_tuple[0], pooled_output) + output_tuple[1:]
 
     def enable_hidden_states_output(self):
         self.model.output_hidden_states = True
@@ -1030,6 +1063,8 @@ class Electra(LanguageModel):
         input_ids: torch.Tensor,
         segment_ids: torch.Tensor,
         padding_mask: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
         **kwargs,
     ):
         """
@@ -1038,26 +1073,31 @@ class Electra(LanguageModel):
         :param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
         :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
                              of shape [batch_size, max_seq_len]
+        :param output_hidden_states: Whether to output hidden states in addition to the embeddings
+        :param output_attentions: Whether to output attentions in addition to the embeddings
         :return: Embeddings for each token in the input sequence.
         """
         output_tuple = self.model(
             input_ids,
             token_type_ids=segment_ids,
             attention_mask=padding_mask,
+            return_dict=False
         )
+        if output_hidden_states is None:
+            output_hidden_states = self.model.encoder.config.output_hidden_states
+        if output_attentions is None:
+            output_attentions = self.model.encoder.config.output_attentions
+        output_tuple = self.model(
+            input_ids,
+            attention_mask=padding_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions
+        )
         # We need to manually aggregate that to get a pooled output (one vec per seq)
         pooled_output = self.pooler(output_tuple[0])
-        if self.model.config.output_hidden_states == True:
-            sequence_output, all_hidden_states = output_tuple[0], output_tuple[1]
-            return sequence_output, pooled_output
-        else:
-            sequence_output = output_tuple[0]
-            return sequence_output, pooled_output
+        return (output_tuple[0], pooled_output) + output_tuple[1:]
 
-    def enable_hidden_states_output(self):
-        self.model.config.output_hidden_states = True
 
     def disable_hidden_states_output(self):
         self.model.config.output_hidden_states = False
@@ -1439,10 +1479,12 @@ class BigBird(LanguageModel):
         input_ids: torch.Tensor,
         segment_ids: torch.Tensor,
         padding_mask: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
         **kwargs,
     ):
         """
-        Perform the forward pass of the BERT model.
+        Perform the forward pass of the BigBird model.
 
         :param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
         :param segment_ids: The id of the segment. For example, in next sentence prediction, the tokens in the
@@ -1450,19 +1492,24 @@ class BigBird(LanguageModel):
                            It is a tensor of shape [batch_size, max_seq_len]
         :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
                              of shape [batch_size, max_seq_len]
+        :param output_hidden_states: Whether to output hidden states in addition to the embeddings
+        :param output_attentions: Whether to output attentions in addition to the embeddings
         :return: Embeddings for each token in the input sequence.
         """
+        if output_hidden_states is None:
+            output_hidden_states = self.model.encoder.config.output_hidden_states
+        if output_attentions is None:
+            output_attentions = self.model.encoder.config.output_attentions
         output_tuple = self.model(
             input_ids,
             token_type_ids=segment_ids,
             attention_mask=padding_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            return_dict=False
         )
-        if self.model.encoder.config.output_hidden_states == True:
-            sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2]
-            return sequence_output, pooled_output, all_hidden_states
-        else:
-            sequence_output, pooled_output = output_tuple[0], output_tuple[1]
-            return sequence_output, pooled_output
+        return output_tuple
 
     def enable_hidden_states_output(self):
         self.model.encoder.config.output_hidden_states = True
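
Taken together, these hunks change the forward contract of every wrapper: `output_hidden_states` / `output_attentions` can now be requested per call (falling back to the model config when left as `None`), and the raw Hugging Face output tuple is returned instead of being unpacked. A minimal usage sketch, assuming the Haystack 1.x module path and loader (variable values are illustrative, not from the diff):

```python
import torch

# Assumed import path and loader from Haystack 1.x's modeling layer.
from haystack.modeling.model.language_model import LanguageModel

lm = LanguageModel.load("roberta-base")  # any of the wrapped architectures above

batch_size, max_seq_len = 2, 16
input_ids = torch.randint(0, 1000, (batch_size, max_seq_len))
segment_ids = torch.zeros_like(input_ids)
padding_mask = torch.ones_like(input_ids)

# With both flags on and return_dict=False, BERT-style wrappers hand back the raw
# Hugging Face tuple: (sequence_output, pooled_output, all_hidden_states, all_attentions).
output_tuple = lm.forward(
    input_ids=input_ids,
    segment_ids=segment_ids,
    padding_mask=padding_mask,
    output_hidden_states=True,  # per-call request, no config toggling needed
    output_attentions=True,
)
sequence_output, pooled_output = output_tuple[0], output_tuple[1]
hidden_states, attentions = output_tuple[2], output_tuple[3]
```

This per-call access is what lets intermediate-layer distillation pull hidden states and attentions from non-BERT teachers and students without mutating the model config.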

View File

@@ -188,8 +188,16 @@
    "source": [
     "# Downloading script\n",
     "!wget https://raw.githubusercontent.com/deepset-ai/haystack/master/haystack/utils/augment_squad.py\n",
-    "# Just replace the path with your dataset and adjust the output\n",
-    "!python augment_squad.py --squad_path data/squad20/dev-v2.0.json --output_path augmented_dataset.json --multiplication_factor 2"
+    "\n",
+    "# Downloading smaller glove vector file (only for demonstration purposes)\n",
+    "!wget https://nlp.stanford.edu/data/glove.6B.zip\n",
+    "!unzip glove.6B.zip\n",
+    "\n",
+    "# Downloading very small dataset to make tutorial faster (please use a bigger dataset for real use cases)\n",
+    "!wget https://raw.githubusercontent.com/deepset-ai/haystack/master/test/samples/squad/small.json\n",
+    "\n",
+    "# Just replace the path with your dataset and adjust the output (also please remove glove path to use bigger glove vector file)\n",
+    "!python augment_squad.py --squad_path small.json --output_path augmented_dataset.json --multiplication_factor 2 --glove_path glove.6B.300d.txt"
    ]
   },
   {
@@ -217,7 +225,7 @@
    "# The number of the layers in the teacher model also needs to be a multiple of the number of the layers in the student.\n",
    "student = FARMReader(model_name_or_path=\"huawei-noah/TinyBERT_General_6L_768D\", use_gpu=True)\n",
    "\n",
-   "student.distil_intermediate_layers_from(teacher, data_dir=\"data/squad20\", train_filename=\"augmented_dataset.json\", use_gpu=True)\n",
+   "student.distil_intermediate_layers_from(teacher, data_dir=\".\", train_filename=\"augmented_dataset.json\", use_gpu=True)\n",
    "student.distil_prediction_layer_from(teacher, data_dir=\"data/squad20\", train_filename=\"dev-v2.0.json\", use_gpu=True)\n",
    "\n",
    "student.save(directory=\"my_distilled_model\")"

View File

@@ -12,6 +12,7 @@ from haystack.utils import augment_squad
 from pathlib import Path
+import os
 
 
 def tutorial2_finetune_a_model_on_your_data():
 
     # ## Create Training Data
@@ -65,8 +66,16 @@ def distil():
     # ### Augmenting your training data
     # To get the most out of model distillation, we recommend increasing the size of your training data by using data augmentation.
     # You can do this by running the [`augment_squad.py` script](https://github.com/deepset-ai/haystack/blob/master/haystack/utils/augment_squad.py):
-    # # Just replace dataset.json with the name of your dataset and adjust the output path
-    augment_squad.main(squad_path=Path("dataset.json"), output_path=Path("augmented_dataset.json"), multiplication_factor=2)
+
+    # Downloading smaller glove vector file (only for demonstration purposes)
+    os.system("wget https://nlp.stanford.edu/data/glove.6B.zip")
+    os.system("unzip glove.6B.zip")
+
+    # Downloading very small dataset to make tutorial faster (please use a bigger dataset in real use cases)
+    os.system("wget https://raw.githubusercontent.com/deepset-ai/haystack/master/test/samples/squad/small.json")
+
+    # Just replace dataset.json with the name of your dataset and adjust the output path
+    augment_squad.main(squad_path=Path("dataset.json"), output_path=Path("augmented_dataset.json"), multiplication_factor=2, glove_path=Path("glove.6B.300d.txt"))
     # In this case, we use a multiplication factor of 2 to keep this example lightweight.
     # Usually you would use a factor like 20 depending on the size of your training data.
     # Augmenting this small dataset with a multiplication factor of 2, should take about 5 to 10 minutes to run on one V100 GPU.
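
For a real dataset, the same programmatic call only needs different inputs. A sketch under the assumption that `my_dataset.json` is a hypothetical placeholder for your own SQuAD-format training file, using the larger multiplication factor the comments above recommend:

```python
from pathlib import Path

from haystack.utils import augment_squad

# `my_dataset.json` is a placeholder, not a file shipped with Haystack.
# multiplication_factor=20 follows the recommendation above for real training data;
# a larger GloVe file than glove.6B.300d.txt should further improve augmentation quality.
augment_squad.main(
    squad_path=Path("my_dataset.json"),
    output_path=Path("augmented_dataset.json"),
    multiplication_factor=20,
    glove_path=Path("glove.6B.300d.txt"),
)
```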