diff --git a/pdelfin/train/molmo/modeling_molmo.py b/pdelfin/train/molmo/modeling_molmo.py
index d68da04..87e687e 100644
--- a/pdelfin/train/molmo/modeling_molmo.py
+++ b/pdelfin/train/molmo/modeling_molmo.py
@@ -1311,7 +1311,7 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
     def __init__(self, config: FullMolmoConfig):
         super().__init__(config)
         v_cfg = self.config.vision_backbone
-        self.grad_checkpointing = False
+        self.grad_checkpointing = True
         self.num_prefix_tokens = self.image_vit.num_prefix_tokens
         assert self.num_prefix_tokens in {0, 1}, "Only 0 or 1 prefix tokens are supported"
 
@@ -1688,7 +1688,7 @@ class Molmo(nn.Module):
                     "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
                 )
         torch.backends.cuda.enable_flash_sdp(True)
-        torch.backends.cuda.enable_mem_efficient_sdp(False)  # this is super slow so make sure torch won't use it
+        torch.backends.cuda.enable_mem_efficient_sdp(True)  # this is super slow so make sure torch won't use it
 
         wte = None
         if self.config.additional_vocab_size is not None:
@@ -1741,6 +1741,8 @@ class Molmo(nn.Module):
 
         self.__num_fwd_flops: Optional[int] = None
 
+        self.gradient_checkpointing = False
+
     def reset_parameters(self):
         if self.vision_backbone is not None:
             self.vision_backbone.reset_parameters()
@@ -1951,7 +1953,11 @@ class Molmo(nn.Module):
                 all_hidden_states.append(x)
 
             layer_past = None if past_key_values is None else past_key_values[block_idx]
-            x, cache = block(x, attention_bias=attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)
+
+            if self.gradient_checkpointing and self.training:
+                x, cache = self._gradient_checkpointing_func(block, x, attention_bias=attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)
+            else:
+                x, cache = block(x, attention_bias=attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)
 
             if attn_key_values is not None:
                 assert cache is not None
@@ -2011,6 +2017,7 @@ class Molmo(nn.Module):
 
 class MolmoForCausalLM(PreTrainedModel):
     config_class = MolmoConfig
+    supports_gradient_checkpointing = True
     base_model_prefix = "model"
     _no_split_modules = ["MolmoBlock"]
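
A minimal usage sketch of how this wiring would be driven from training code (not part of the diff; the checkpoint path is a placeholder and the exact loading call is an assumption). With `supports_gradient_checkpointing = True` on `MolmoForCausalLM`, calling the standard Hugging Face `gradient_checkpointing_enable()` flips the new `gradient_checkpointing` flag on the inner `Molmo` module and attaches `_gradient_checkpointing_func`, which the patched forward loop then calls per `MolmoBlock`:

```python
# Sketch only: checkpoint path is a placeholder, not a value from the diff.
import torch
from pdelfin.train.molmo.modeling_molmo import MolmoForCausalLM

model = MolmoForCausalLM.from_pretrained("path/to/molmo-checkpoint", torch_dtype=torch.bfloat16)

# Non-reentrant checkpointing is used because the patched forward passes keyword
# arguments (attention_bias, position_ids, layer_past, use_cache) to each block,
# which the reentrant torch.utils.checkpoint implementation does not accept.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

model.train()  # the checkpointed path only runs when gradient_checkpointing and self.training are both True
```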