Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-08-20 06:28:39 +00:00)
Fix finetuning notebook augmentation (#2071)
* fix data augmentation path in finetuning notebook
* Add latest docstring and tutorial changes
* make distillation possible with other models than BERT
* use smaller dataset for distillation in finetuning tutorial
* Add latest docstring and tutorial changes
* make data augmentation in finetuning faster
* update language models forward doc strings
* fix return type of language models
* remove debug output

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
parent c4fff19018
commit 4cc37548e3
@@ -104,8 +104,16 @@ To get the most out of model distillation, we recommend increasing the size of y
```python
# Downloading script
!wget https://raw.githubusercontent.com/deepset-ai/haystack/master/haystack/utils/augment_squad.py
# Just replace the path with your dataset and adjust the output
!python augment_squad.py --squad_path data/squad20/dev-v2.0.json --output_path augmented_dataset.json --multiplication_factor 2

# Downloading smaller glove vector file (only for demonstration purposes)
!wget https://nlp.stanford.edu/data/glove.6B.zip
!unzip glove.6B.zip

# Downloading very small dataset to make tutorial faster (please use a bigger dataset for real use cases)
!wget https://raw.githubusercontent.com/deepset-ai/haystack/master/test/samples/squad/small.json

# Just replace the path with your dataset and adjust the output (also please remove glove path to use bigger glove vector file)
!python augment_squad.py --squad_path small.json --output_path augmented_dataset.json --multiplication_factor 2 --glove_path glove.6B.300d.txt
```

In this case, we use a multiplication factor of 2 to keep this example lightweight. Usually you would use a factor like 20 depending on the size of your training data. Augmenting this small dataset with a multiplication factor of 2, should take about 5 to 10 minutes to run on one V100 GPU.
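The same augmentation can also be triggered from Python rather than the shell, which is what the updated tutorial script later in this commit does. A minimal sketch (not part of the diff) using the demo paths from above:

```python
# Sketch: programmatic equivalent of the shell commands above
# (demo paths; swap in your own dataset for real use).
from pathlib import Path

from haystack.utils import augment_squad

augment_squad.main(
    squad_path=Path("small.json"),               # small demo dataset downloaded above
    output_path=Path("augmented_dataset.json"),  # augmented SQuAD file is written here
    multiplication_factor=2,                     # use something like 20 for real training data
    glove_path=Path("glove.6B.300d.txt"),        # demo GloVe vectors; omit to use the bigger default file
)
```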
@@ -124,7 +132,7 @@ teacher = FARMReader(model_name_or_path="my_model", use_gpu=True)
# The number of the layers in the teacher model also needs to be a multiple of the number of the layers in the student.
student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D", use_gpu=True)

student.distil_intermediate_layers_from(teacher, data_dir="data/squad20", train_filename="augmented_dataset.json", use_gpu=True)
student.distil_intermediate_layers_from(teacher, data_dir=".", train_filename="augmented_dataset.json", use_gpu=True)
student.distil_prediction_layer_from(teacher, data_dir="data/squad20", train_filename="dev-v2.0.json", use_gpu=True)

student.save(directory="my_distilled_model")
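Once saved, the distilled student can be reloaded like any other reader. A small sketch (not part of the diff; the import path is assumed for this Haystack version):

```python
# Sketch: reload the distilled student saved above and use it as a drop-in
# replacement for the teacher in any QA pipeline.
from haystack.nodes import FARMReader  # assumed import path

student = FARMReader(model_name_or_path="my_distilled_model", use_gpu=True)
```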
@@ -572,6 +572,8 @@ class Albert(LanguageModel):
        input_ids: torch.Tensor,
        segment_ids: torch.Tensor,
        padding_mask: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ):
        """
@@ -583,19 +585,24 @@ class Albert(LanguageModel):
           It is a tensor of shape [batch_size, max_seq_len]
        :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
           of shape [batch_size, max_seq_len]
        :param output_hidden_states: Whether to output hidden states in addition to the embeddings
        :param output_attentions: Whether to output attentions in addition to the embeddings
        :return: Embeddings for each token in the input sequence.
        """
        if output_hidden_states is None:
            output_hidden_states = self.model.encoder.config.output_hidden_states
        if output_attentions is None:
            output_attentions = self.model.encoder.config.output_attentions

        output_tuple = self.model(
            input_ids,
            token_type_ids=segment_ids,
            attention_mask=padding_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=False
        )
        if self.model.encoder.config.output_hidden_states == True:
            sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2]
            return sequence_output, pooled_output, all_hidden_states
        else:
            sequence_output, pooled_output = output_tuple[0], output_tuple[1]
            return sequence_output, pooled_output
        return output_tuple

    def enable_hidden_states_output(self):
        self.model.encoder.config.output_hidden_states = True
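The new `return output_tuple` hands callers the raw Hugging Face tuple, whose length depends on the requested output flags. A standalone sketch (not part of the diff, using `transformers` directly rather than Haystack's wrapper; model name only for illustration) of what that tuple looks like:

```python
# Sketch: with return_dict=False the ALBERT model returns a plain tuple whose
# first two entries are (sequence_output, pooled_output); extra entries
# (hidden states, attentions) are appended only when requested.
from transformers import AlbertModel, AlbertTokenizerFast

tokenizer = AlbertTokenizerFast.from_pretrained("albert-base-v2")
model = AlbertModel.from_pretrained("albert-base-v2")

batch = tokenizer(["Haystack distils small readers from large ones."], return_tensors="pt")
output_tuple = model(
    batch["input_ids"],
    token_type_ids=batch["token_type_ids"],
    attention_mask=batch["attention_mask"],
    output_hidden_states=True,
    output_attentions=False,
    return_dict=False,
)
sequence_output, pooled_output = output_tuple[0], output_tuple[1]
print(sequence_output.shape)  # [batch_size, max_seq_len, hidden_size]
print(pooled_output.shape)    # [batch_size, hidden_size]
print(len(output_tuple))      # 3 here: the extra entry holds all hidden states
```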
@@ -654,6 +661,8 @@ class Roberta(LanguageModel):
        input_ids: torch.Tensor,
        segment_ids: torch.Tensor,
        padding_mask: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ):
        """
@@ -665,19 +674,24 @@ class Roberta(LanguageModel):
           It is a tensor of shape [batch_size, max_seq_len]
        :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
           of shape [batch_size, max_seq_len]
        :param output_hidden_states: Whether to output hidden states in addition to the embeddings
        :param output_attentions: Whether to output attentions in addition to the embeddings
        :return: Embeddings for each token in the input sequence.
        """
        if output_hidden_states is None:
            output_hidden_states = self.model.encoder.config.output_hidden_states
        if output_attentions is None:
            output_attentions = self.model.encoder.config.output_attentions

        output_tuple = self.model(
            input_ids,
            token_type_ids=segment_ids,
            attention_mask=padding_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=False
        )
        if self.model.encoder.config.output_hidden_states == True:
            sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2]
            return sequence_output, pooled_output, all_hidden_states
        else:
            sequence_output, pooled_output = output_tuple[0], output_tuple[1]
            return sequence_output, pooled_output
        return output_tuple

    def enable_hidden_states_output(self):
        self.model.encoder.config.output_hidden_states = True
@@ -736,6 +750,8 @@ class XLMRoberta(LanguageModel):
        input_ids: torch.Tensor,
        segment_ids: torch.Tensor,
        padding_mask: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ):
        """
@@ -747,19 +763,24 @@ class XLMRoberta(LanguageModel):
           It is a tensor of shape [batch_size, max_seq_len]
        :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
           of shape [batch_size, max_seq_len]
        :param output_hidden_states: Whether to output hidden states in addition to the embeddings
        :param output_attentions: Whether to output attentions in addition to the embeddings
        :return: Embeddings for each token in the input sequence.
        """
        if output_hidden_states is None:
            output_hidden_states = self.model.encoder.config.output_hidden_states
        if output_attentions is None:
            output_attentions = self.model.encoder.config.output_attentions

        output_tuple = self.model(
            input_ids,
            token_type_ids=segment_ids,
            attention_mask=padding_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=False
        )
        if self.model.encoder.config.output_hidden_states == True:
            sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2]
            return sequence_output, pooled_output, all_hidden_states
        else:
            sequence_output, pooled_output = output_tuple[0], output_tuple[1]
            return sequence_output, pooled_output
        return output_tuple

    def enable_hidden_states_output(self):
        self.model.encoder.config.output_hidden_states = True
@@ -832,6 +853,8 @@ class DistilBert(LanguageModel):
        self,
        input_ids: torch.Tensor,
        padding_mask: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ):
        """
@@ -840,20 +863,25 @@ class DistilBert(LanguageModel):
        :param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
        :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
           of shape [batch_size, max_seq_len]
        :param output_hidden_states: Whether to output hidden states in addition to the embeddings
        :param output_attentions: Whether to output attentions in addition to the embeddings
        :return: Embeddings for each token in the input sequence.
        """
        if output_hidden_states is None:
            output_hidden_states = self.model.encoder.config.output_hidden_states
        if output_attentions is None:
            output_attentions = self.model.encoder.config.output_attentions

        output_tuple = self.model(
            input_ids,
            attention_mask=padding_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=False
        )
        # We need to manually aggregate that to get a pooled output (one vec per seq)
        pooled_output = self.pooler(output_tuple[0])
        if self.model.config.output_hidden_states == True:
            sequence_output, all_hidden_states = output_tuple[0], output_tuple[1]
            return sequence_output, pooled_output
        else:
            sequence_output = output_tuple[0]
            return sequence_output, pooled_output
        return (output_tuple[0], pooled_output) + output_tuple[1:]

    def enable_hidden_states_output(self):
        self.model.config.output_hidden_states = True
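DistilBERT has no pooled output of its own, which is why the wrapper aggregates one manually (`pooled_output = self.pooler(output_tuple[0])`). A standalone sketch (not part of the diff) of the idea, using the first-token hidden state as one simple pooling choice assumed here for illustration; Haystack's own pooler may differ:

```python
# Sketch: DistilBERT returns only per-token embeddings, so a sequence-level
# ("pooled") vector has to be derived from the sequence output.
from transformers import DistilBertModel, DistilBertTokenizerFast

tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased")

batch = tokenizer(["a short example"], return_tensors="pt")
(sequence_output,) = model(
    batch["input_ids"],
    attention_mask=batch["attention_mask"],
    return_dict=False,
)
pooled_output = sequence_output[:, 0]  # first-token vector, one per sequence (assumed pooling choice)
print(sequence_output.shape, pooled_output.shape)
```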
@@ -921,6 +949,8 @@ class XLNet(LanguageModel):
        input_ids: torch.Tensor,
        segment_ids: torch.Tensor,
        padding_mask: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ):
        """
@@ -932,26 +962,29 @@ class XLNet(LanguageModel):
           It is a tensor of shape [batch_size, max_seq_len]
        :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
           of shape [batch_size, max_seq_len]
        :param output_hidden_states: Whether to output hidden states in addition to the embeddings
        :param output_attentions: Whether to output attentions in addition to the embeddings
        :return: Embeddings for each token in the input sequence.
        """
        if output_hidden_states is None:
            output_hidden_states = self.model.encoder.config.output_hidden_states
        if output_attentions is None:
            output_attentions = self.model.encoder.config.output_attentions

        # Note: XLNet has a couple of special input tensors for pretraining / text generation (perm_mask, target_mapping ...)
        # We will need to implement them, if we wanna support LM adaptation
        output_tuple = self.model(
            input_ids,
            token_type_ids=segment_ids,
            attention_mask=padding_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=False
        )
        # XLNet also only returns the sequence_output (one vec per token)
        # We need to manually aggregate that to get a pooled output (one vec per seq)
        # TODO verify that this is really doing correct pooling
        pooled_output = self.pooler(output_tuple[0])

        if self.model.output_hidden_states == True:
            sequence_output, all_hidden_states = output_tuple[0], output_tuple[1]
            return sequence_output, pooled_output, all_hidden_states
        else:
            sequence_output = output_tuple[0]
            return sequence_output, pooled_output
        return (output_tuple[0], pooled_output) + output_tuple[1:]

    def enable_hidden_states_output(self):
        self.model.output_hidden_states = True
@@ -1030,6 +1063,8 @@ class Electra(LanguageModel):
        input_ids: torch.Tensor,
        segment_ids: torch.Tensor,
        padding_mask: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ):
        """
@@ -1038,26 +1073,31 @@ class Electra(LanguageModel):
        :param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
        :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
           of shape [batch_size, max_seq_len]
        :param output_hidden_states: Whether to output hidden states in addition to the embeddings
        :param output_attentions: Whether to output attentions in addition to the embeddings
        :return: Embeddings for each token in the input sequence.
        """
        output_tuple = self.model(
            input_ids,
            token_type_ids=segment_ids,
            attention_mask=padding_mask,
            return_dict=False
        )

        if output_hidden_states is None:
            output_hidden_states = self.model.encoder.config.output_hidden_states
        if output_attentions is None:
            output_attentions = self.model.encoder.config.output_attentions

        output_tuple = self.model(
            input_ids,
            attention_mask=padding_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions
        )
        # We need to manually aggregate that to get a pooled output (one vec per seq)
        pooled_output = self.pooler(output_tuple[0])

        if self.model.config.output_hidden_states == True:
            sequence_output, all_hidden_states = output_tuple[0], output_tuple[1]
            return sequence_output, pooled_output
        else:
            sequence_output = output_tuple[0]
            return sequence_output, pooled_output

    def enable_hidden_states_output(self):
        self.model.config.output_hidden_states = True
        return (output_tuple[0], pooled_output) + output_tuple[1:]

    def disable_hidden_states_output(self):
        self.model.config.output_hidden_states = False
@@ -1439,10 +1479,12 @@ class BigBird(LanguageModel):
        input_ids: torch.Tensor,
        segment_ids: torch.Tensor,
        padding_mask: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ):
        """
        Perform the forward pass of the BERT model.
        Perform the forward pass of the BigBird model.

        :param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
        :param segment_ids: The id of the segment. For example, in next sentence prediction, the tokens in the
@@ -1450,19 +1492,24 @@ class BigBird(LanguageModel):
           It is a tensor of shape [batch_size, max_seq_len]
        :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
           of shape [batch_size, max_seq_len]
        :param output_hidden_states: Whether to output hidden states in addition to the embeddings
        :param output_attentions: Whether to output attentions in addition to the embeddings
        :return: Embeddings for each token in the input sequence.
        """
        if output_hidden_states is None:
            output_hidden_states = self.model.encoder.config.output_hidden_states
        if output_attentions is None:
            output_attentions = self.model.encoder.config.output_attentions

        output_tuple = self.model(
            input_ids,
            token_type_ids=segment_ids,
            attention_mask=padding_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=False
        )
        if self.model.encoder.config.output_hidden_states == True:
            sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2]
            return sequence_output, pooled_output, all_hidden_states
        else:
            sequence_output, pooled_output = output_tuple[0], output_tuple[1]
            return sequence_output, pooled_output
        return output_tuple

    def enable_hidden_states_output(self):
        self.model.encoder.config.output_hidden_states = True
@@ -188,8 +188,16 @@
"source": [
"# Downloading script\n",
"!wget https://raw.githubusercontent.com/deepset-ai/haystack/master/haystack/utils/augment_squad.py\n",
"# Just replace the path with your dataset and adjust the output\n",
"!python augment_squad.py --squad_path data/squad20/dev-v2.0.json --output_path augmented_dataset.json --multiplication_factor 2"
"\n",
"# Downloading smaller glove vector file (only for demonstration purposes)\n",
"!wget https://nlp.stanford.edu/data/glove.6B.zip\n",
"!unzip glove.6B.zip\n",
"\n",
"# Downloading very small dataset to make tutorial faster (please use a bigger dataset for real use cases)\n",
"!wget https://raw.githubusercontent.com/deepset-ai/haystack/master/test/samples/squad/small.json\n",
"\n",
"# Just replace the path with your dataset and adjust the output (also please remove glove path to use bigger glove vector file)\n",
"!python augment_squad.py --squad_path small.json --output_path augmented_dataset.json --multiplication_factor 2 --glove_path glove.6B.300d.txt"
]
},
{
@@ -217,7 +225,7 @@
"# The number of the layers in the teacher model also needs to be a multiple of the number of the layers in the student.\n",
"student = FARMReader(model_name_or_path=\"huawei-noah/TinyBERT_General_6L_768D\", use_gpu=True)\n",
"\n",
"student.distil_intermediate_layers_from(teacher, data_dir=\"data/squad20\", train_filename=\"augmented_dataset.json\", use_gpu=True)\n",
"student.distil_intermediate_layers_from(teacher, data_dir=\".\", train_filename=\"augmented_dataset.json\", use_gpu=True)\n",
"student.distil_prediction_layer_from(teacher, data_dir=\"data/squad20\", train_filename=\"dev-v2.0.json\", use_gpu=True)\n",
"\n",
"student.save(directory=\"my_distilled_model\")"
@@ -12,6 +12,7 @@ from haystack.utils import augment_squad

from pathlib import Path

import os

def tutorial2_finetune_a_model_on_your_data():
    # ## Create Training Data
@@ -65,8 +66,16 @@ def distil():
    # ### Augmenting your training data
    # To get the most out of model distillation, we recommend increasing the size of your training data by using data augmentation.
    # You can do this by running the [`augment_squad.py` script](https://github.com/deepset-ai/haystack/blob/master/haystack/utils/augment_squad.py):
    # # Just replace dataset.json with the name of your dataset and adjust the output path
    augment_squad.main(squad_path=Path("dataset.json"), output_path=Path("augmented_dataset.json"), multiplication_factor=2)

    # Downloading smaller glove vector file (only for demonstration purposes)
    os.system("wget https://nlp.stanford.edu/data/glove.6B.zip")
    os.system("unzip glove.6B.zip")

    # Downloading very small dataset to make tutorial faster (please use a bigger dataset in real use cases)
    os.system("wget https://raw.githubusercontent.com/deepset-ai/haystack/master/test/samples/squad/small.json")

    # Just replace dataset.json with the name of your dataset and adjust the output path
    augment_squad.main(squad_path=Path("dataset.json"), output_path=Path("augmented_dataset.json"), multiplication_factor=2, glove_path=Path("glove.6B.300d.txt"))
    # In this case, we use a multiplication factor of 2 to keep this example lightweight.
    # Usually you would use a factor like 20 depending on the size of your training data.
    # Augmenting this small dataset with a multiplication factor of 2, should take about 5 to 10 minutes to run on one V100 GPU.