diff --git a/docs/_src/tutorials/tutorials/2.md b/docs/_src/tutorials/tutorials/2.md index 48854daac..e8fc03a8d 100644 --- a/docs/_src/tutorials/tutorials/2.md +++ b/docs/_src/tutorials/tutorials/2.md @@ -104,8 +104,16 @@ To get the most out of model distillation, we recommend increasing the size of y ```python # Downloading script !wget https://raw.githubusercontent.com/deepset-ai/haystack/master/haystack/utils/augment_squad.py -# Just replace the path with your dataset and adjust the output -!python augment_squad.py --squad_path data/squad20/dev-v2.0.json --output_path augmented_dataset.json --multiplication_factor 2 + +# Downloading smaller glove vector file (only for demonstration purposes) +!wget https://nlp.stanford.edu/data/glove.6B.zip +!unzip glove.6B.zip + +# Downloading very small dataset to make tutorial faster (please use a bigger dataset for real use cases) +!wget https://raw.githubusercontent.com/deepset-ai/haystack/master/test/samples/squad/small.json + +# Just replace the path with your dataset and adjust the output (also please remove glove path to use bigger glove vector file) +!python augment_squad.py --squad_path small.json --output_path augmented_dataset.json --multiplication_factor 2 --glove_path glove.6B.300d.txt ``` In this case, we use a multiplication factor of 2 to keep this example lightweight. Usually you would use a factor like 20 depending on the size of your training data. Augmenting this small dataset with a multiplication factor of 2, should take about 5 to 10 minutes to run on one V100 GPU. @@ -124,7 +132,7 @@ teacher = FARMReader(model_name_or_path="my_model", use_gpu=True) # The number of the layers in the teacher model also needs to be a multiple of the number of the layers in the student. student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D", use_gpu=True) -student.distil_intermediate_layers_from(teacher, data_dir="data/squad20", train_filename="augmented_dataset.json", use_gpu=True) +student.distil_intermediate_layers_from(teacher, data_dir=".", train_filename="augmented_dataset.json", use_gpu=True) student.distil_prediction_layer_from(teacher, data_dir="data/squad20", train_filename="dev-v2.0.json", use_gpu=True) student.save(directory="my_distilled_model") diff --git a/haystack/modeling/model/language_model.py b/haystack/modeling/model/language_model.py index 875674670..89667f55e 100644 --- a/haystack/modeling/model/language_model.py +++ b/haystack/modeling/model/language_model.py @@ -572,6 +572,8 @@ class Albert(LanguageModel): input_ids: torch.Tensor, segment_ids: torch.Tensor, padding_mask: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, **kwargs, ): """ @@ -583,19 +585,24 @@ class Albert(LanguageModel): It is a tensor of shape [batch_size, max_seq_len] :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens of shape [batch_size, max_seq_len] + :param output_hidden_states: Whether to output hidden states in addition to the embeddings + :param output_attentions: Whether to output attentions in addition to the embeddings :return: Embeddings for each token in the input sequence. 
""" + if output_hidden_states is None: + output_hidden_states = self.model.encoder.config.output_hidden_states + if output_attentions is None: + output_attentions = self.model.encoder.config.output_attentions + output_tuple = self.model( input_ids, token_type_ids=segment_ids, attention_mask=padding_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=False ) - if self.model.encoder.config.output_hidden_states == True: - sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2] - return sequence_output, pooled_output, all_hidden_states - else: - sequence_output, pooled_output = output_tuple[0], output_tuple[1] - return sequence_output, pooled_output + return output_tuple def enable_hidden_states_output(self): self.model.encoder.config.output_hidden_states = True @@ -654,6 +661,8 @@ class Roberta(LanguageModel): input_ids: torch.Tensor, segment_ids: torch.Tensor, padding_mask: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, **kwargs, ): """ @@ -665,19 +674,24 @@ class Roberta(LanguageModel): It is a tensor of shape [batch_size, max_seq_len] :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens of shape [batch_size, max_seq_len] + :param output_hidden_states: Whether to output hidden states in addition to the embeddings + :param output_attentions: Whether to output attentions in addition to the embeddings :return: Embeddings for each token in the input sequence. """ + if output_hidden_states is None: + output_hidden_states = self.model.encoder.config.output_hidden_states + if output_attentions is None: + output_attentions = self.model.encoder.config.output_attentions + output_tuple = self.model( input_ids, token_type_ids=segment_ids, attention_mask=padding_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=False ) - if self.model.encoder.config.output_hidden_states == True: - sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2] - return sequence_output, pooled_output, all_hidden_states - else: - sequence_output, pooled_output = output_tuple[0], output_tuple[1] - return sequence_output, pooled_output + return output_tuple def enable_hidden_states_output(self): self.model.encoder.config.output_hidden_states = True @@ -736,6 +750,8 @@ class XLMRoberta(LanguageModel): input_ids: torch.Tensor, segment_ids: torch.Tensor, padding_mask: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, **kwargs, ): """ @@ -747,19 +763,24 @@ class XLMRoberta(LanguageModel): It is a tensor of shape [batch_size, max_seq_len] :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens of shape [batch_size, max_seq_len] + :param output_hidden_states: Whether to output hidden states in addition to the embeddings + :param output_attentions: Whether to output attentions in addition to the embeddings :return: Embeddings for each token in the input sequence. 
""" + if output_hidden_states is None: + output_hidden_states = self.model.encoder.config.output_hidden_states + if output_attentions is None: + output_attentions = self.model.encoder.config.output_attentions + output_tuple = self.model( input_ids, token_type_ids=segment_ids, attention_mask=padding_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=False ) - if self.model.encoder.config.output_hidden_states == True: - sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2] - return sequence_output, pooled_output, all_hidden_states - else: - sequence_output, pooled_output = output_tuple[0], output_tuple[1] - return sequence_output, pooled_output + return output_tuple def enable_hidden_states_output(self): self.model.encoder.config.output_hidden_states = True @@ -832,6 +853,8 @@ class DistilBert(LanguageModel): self, input_ids: torch.Tensor, padding_mask: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, **kwargs, ): """ @@ -840,20 +863,25 @@ class DistilBert(LanguageModel): :param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len] :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens of shape [batch_size, max_seq_len] + :param output_hidden_states: Whether to output hidden states in addition to the embeddings + :param output_attentions: Whether to output attentions in addition to the embeddings :return: Embeddings for each token in the input sequence. """ + if output_hidden_states is None: + output_hidden_states = self.model.encoder.config.output_hidden_states + if output_attentions is None: + output_attentions = self.model.encoder.config.output_attentions + output_tuple = self.model( input_ids, attention_mask=padding_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=False ) # We need to manually aggregate that to get a pooled output (one vec per seq) pooled_output = self.pooler(output_tuple[0]) - if self.model.config.output_hidden_states == True: - sequence_output, all_hidden_states = output_tuple[0], output_tuple[1] - return sequence_output, pooled_output - else: - sequence_output = output_tuple[0] - return sequence_output, pooled_output + return (output_tuple[0], pooled_output) + output_tuple[1:] def enable_hidden_states_output(self): self.model.config.output_hidden_states = True @@ -921,6 +949,8 @@ class XLNet(LanguageModel): input_ids: torch.Tensor, segment_ids: torch.Tensor, padding_mask: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, **kwargs, ): """ @@ -932,26 +962,29 @@ class XLNet(LanguageModel): It is a tensor of shape [batch_size, max_seq_len] :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens of shape [batch_size, max_seq_len] + :param output_hidden_states: Whether to output hidden states in addition to the embeddings + :param output_attentions: Whether to output attentions in addition to the embeddings :return: Embeddings for each token in the input sequence. """ + if output_hidden_states is None: + output_hidden_states = self.model.encoder.config.output_hidden_states + if output_attentions is None: + output_attentions = self.model.encoder.config.output_attentions + # Note: XLNet has a couple of special input tensors for pretraining / text generation (perm_mask, target_mapping ...) 
         # We will need to implement them, if we wanna support LM adaptation
         output_tuple = self.model(
             input_ids,
-            token_type_ids=segment_ids,
             attention_mask=padding_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            return_dict=False
         )
         # XLNet also only returns the sequence_output (one vec per token)
         # We need to manually aggregate that to get a pooled output (one vec per seq)
         # TODO verify that this is really doing correct pooling
         pooled_output = self.pooler(output_tuple[0])
-
-        if self.model.output_hidden_states == True:
-            sequence_output, all_hidden_states = output_tuple[0], output_tuple[1]
-            return sequence_output, pooled_output
-        else:
-            sequence_output = output_tuple[0]
-            return sequence_output, pooled_output
+        return (output_tuple[0], pooled_output) + output_tuple[1:]
 
     def enable_hidden_states_output(self):
         self.model.output_hidden_states = True
@@ -1030,6 +1063,8 @@ class Electra(LanguageModel):
         input_ids: torch.Tensor,
         segment_ids: torch.Tensor,
         padding_mask: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
         **kwargs,
     ):
         """
@@ -1038,26 +1073,31 @@
         :param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
         :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
            of shape [batch_size, max_seq_len]
+        :param output_hidden_states: Whether to output hidden states in addition to the embeddings
+        :param output_attentions: Whether to output attentions in addition to the embeddings
         :return: Embeddings for each token in the input sequence.
         """
+        if output_hidden_states is None:
+            output_hidden_states = self.model.encoder.config.output_hidden_states
+        if output_attentions is None:
+            output_attentions = self.model.encoder.config.output_attentions
+
         output_tuple = self.model(
             input_ids,
             token_type_ids=segment_ids,
             attention_mask=padding_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            return_dict=False
         )
         # We need to manually aggregate that to get a pooled output (one vec per seq)
         pooled_output = self.pooler(output_tuple[0])
-
-        if self.model.config.output_hidden_states == True:
-            sequence_output, all_hidden_states = output_tuple[0], output_tuple[1]
-            return sequence_output, pooled_output
-        else:
-            sequence_output = output_tuple[0]
-            return sequence_output, pooled_output
-
-    def enable_hidden_states_output(self):
-        self.model.config.output_hidden_states = True
+        return (output_tuple[0], pooled_output) + output_tuple[1:]
 
     def disable_hidden_states_output(self):
         self.model.config.output_hidden_states = False
@@ -1439,10 +1479,12 @@ class BigBird(LanguageModel):
         input_ids: torch.Tensor,
         segment_ids: torch.Tensor,
         padding_mask: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
         **kwargs,
     ):
         """
-        Perform the forward pass of the BERT model.
+        Perform the forward pass of the BigBird model.
 
         :param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
         :param segment_ids: The id of the segment.
For example, in next sentence prediction, the tokens in the @@ -1450,19 +1492,24 @@ class BigBird(LanguageModel): It is a tensor of shape [batch_size, max_seq_len] :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens of shape [batch_size, max_seq_len] + :param output_hidden_states: Whether to output hidden states in addition to the embeddings + :param output_attentions: Whether to output attentions in addition to the embeddings :return: Embeddings for each token in the input sequence. """ + if output_hidden_states is None: + output_hidden_states = self.model.encoder.config.output_hidden_states + if output_attentions is None: + output_attentions = self.model.encoder.config.output_attentions + output_tuple = self.model( input_ids, token_type_ids=segment_ids, attention_mask=padding_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=False ) - if self.model.encoder.config.output_hidden_states == True: - sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2] - return sequence_output, pooled_output, all_hidden_states - else: - sequence_output, pooled_output = output_tuple[0], output_tuple[1] - return sequence_output, pooled_output + return output_tuple def enable_hidden_states_output(self): self.model.encoder.config.output_hidden_states = True diff --git a/tutorials/Tutorial2_Finetune_a_model_on_your_data.ipynb b/tutorials/Tutorial2_Finetune_a_model_on_your_data.ipynb index 1a86a22e6..4a327c115 100644 --- a/tutorials/Tutorial2_Finetune_a_model_on_your_data.ipynb +++ b/tutorials/Tutorial2_Finetune_a_model_on_your_data.ipynb @@ -188,8 +188,16 @@ "source": [ "# Downloading script\n", "!wget https://raw.githubusercontent.com/deepset-ai/haystack/master/haystack/utils/augment_squad.py\n", - "# Just replace the path with your dataset and adjust the output\n", - "!python augment_squad.py --squad_path data/squad20/dev-v2.0.json --output_path augmented_dataset.json --multiplication_factor 2" + "\n", + "# Downloading smaller glove vector file (only for demonstration purposes)\n", + "!wget https://nlp.stanford.edu/data/glove.6B.zip\n", + "!unzip glove.6B.zip\n", + "\n", + "# Downloading very small dataset to make tutorial faster (please use a bigger dataset for real use cases)\n", + "!wget https://raw.githubusercontent.com/deepset-ai/haystack/master/test/samples/squad/small.json\n", + "\n", + "# Just replace the path with your dataset and adjust the output (also please remove glove path to use bigger glove vector file)\n", + "!python augment_squad.py --squad_path small.json --output_path augmented_dataset.json --multiplication_factor 2 --glove_path glove.6B.300d.txt" ] }, { @@ -217,7 +225,7 @@ "# The number of the layers in the teacher model also needs to be a multiple of the number of the layers in the student.\n", "student = FARMReader(model_name_or_path=\"huawei-noah/TinyBERT_General_6L_768D\", use_gpu=True)\n", "\n", - "student.distil_intermediate_layers_from(teacher, data_dir=\"data/squad20\", train_filename=\"augmented_dataset.json\", use_gpu=True)\n", + "student.distil_intermediate_layers_from(teacher, data_dir=\".\", train_filename=\"augmented_dataset.json\", use_gpu=True)\n", "student.distil_prediction_layer_from(teacher, data_dir=\"data/squad20\", train_filename=\"dev-v2.0.json\", use_gpu=True)\n", "\n", "student.save(directory=\"my_distilled_model\")" diff --git a/tutorials/Tutorial2_Finetune_a_model_on_your_data.py 
b/tutorials/Tutorial2_Finetune_a_model_on_your_data.py
index a5eab4ba1..f45a1dd0b 100755
--- a/tutorials/Tutorial2_Finetune_a_model_on_your_data.py
+++ b/tutorials/Tutorial2_Finetune_a_model_on_your_data.py
@@ -12,6 +12,7 @@ from haystack.utils import augment_squad
 from pathlib import Path
+import os
 
 
 def tutorial2_finetune_a_model_on_your_data():
     # ## Create Training Data
@@ -65,8 +66,16 @@ def distil():
     # ### Augmenting your training data
     # To get the most out of model distillation, we recommend increasing the size of your training data by using data augmentation.
     # You can do this by running the [`augment_squad.py` script](https://github.com/deepset-ai/haystack/blob/master/haystack/utils/augment_squad.py):
-    # # Just replace dataset.json with the name of your dataset and adjust the output path
-    augment_squad.main(squad_path=Path("dataset.json"), output_path=Path("augmented_dataset.json"), multiplication_factor=2)
+
+    # Downloading smaller glove vector file (only for demonstration purposes)
+    os.system("wget https://nlp.stanford.edu/data/glove.6B.zip")
+    os.system("unzip glove.6B.zip")
+
+    # Downloading very small dataset to make tutorial faster (please use a bigger dataset in real use cases)
+    os.system("wget https://raw.githubusercontent.com/deepset-ai/haystack/master/test/samples/squad/small.json")
+
+    # Just replace the path with your dataset and adjust the output (also please remove glove path to use bigger glove vector file)
+    augment_squad.main(squad_path=Path("small.json"), output_path=Path("augmented_dataset.json"), multiplication_factor=2, glove_path=Path("glove.6B.300d.txt"))
     # In this case, we use a multiplication factor of 2 to keep this example lightweight.
     # Usually you would use a factor like 20 depending on the size of your training data.
     # Augmenting this small dataset with a multiplication factor of 2, should take about 5 to 10 minutes to run on one V100 GPU.
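For orientation: the rewritten `forward` methods above pass through the plain tuple that Hugging Face `transformers` models return when called with `return_dict=False`; the hidden states and attentions, when requested, are appended after the sequence and pooled outputs, which is what the distillation code consumes. A minimal sketch of that tuple layout, using an illustrative `bert-base-uncased` checkpoint and sample sentence that are assumptions, not part of the diff:

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Illustrative checkpoint; the wrappers in the diff apply the same pattern to
# ALBERT, RoBERTa, XLM-RoBERTa, DistilBERT, XLNet, ELECTRA and BigBird.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

batch = tokenizer("Distilling a reader model.", return_tensors="pt")
with torch.no_grad():
    outputs = model(
        **batch,
        output_hidden_states=True,
        output_attentions=True,
        return_dict=False,  # plain tuple instead of a ModelOutput, as in the updated forward() methods
    )

# Tuple layout with both flags enabled:
# (sequence_output, pooled_output, all_hidden_states, all_attentions)
sequence_output, pooled_output, hidden_states, attentions = outputs
print(sequence_output.shape)  # [batch_size, seq_len, hidden_size]
print(pooled_output.shape)    # [batch_size, hidden_size]
print(len(hidden_states))     # num_layers + 1 (embedding output comes first)
print(len(attentions))        # num_layers
```

DistilBERT, XLNet and ELECTRA produce no pooled output of their own, which is why those wrappers build one with `self.pooler` and then re-attach the trailing outputs via `(output_tuple[0], pooled_output) + output_tuple[1:]`.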