Manually adding gradient checkpointing

Jake Poznanski 2025-01-23 15:18:22 -08:00
parent 18569a4c63
commit f42bb02fce


@@ -1311,7 +1311,7 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
     def __init__(self, config: FullMolmoConfig):
         super().__init__(config)
         v_cfg = self.config.vision_backbone
-        self.grad_checkpointing = False
+        self.grad_checkpointing = True
         self.num_prefix_tokens = self.image_vit.num_prefix_tokens
         assert self.num_prefix_tokens in {0, 1}, "Only 0 or 1 prefix tokens are supported"
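The flag flipped above belongs to the vision backbone itself; its forward pass (not shown in this hunk) is what consults it when running the ViT blocks. As a rough illustration of the pattern such a flag usually gates, here is a minimal, self-contained sketch built on torch.utils.checkpoint. TinyViTEncoder and its blocks attribute are illustrative stand-ins, not the actual OLMoPretrainedVisionBackbone code.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TinyViTEncoder(nn.Module):
    """Toy ViT-style encoder showing how a grad_checkpointing flag is typically used."""

    def __init__(self, dim: int = 64, depth: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True) for _ in range(depth)]
        )
        self.grad_checkpointing = True  # same kind of switch this commit flips to True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for blk in self.blocks:
            if self.grad_checkpointing and self.training:
                # Recompute this block's activations during backward instead of
                # storing them, trading extra compute for lower peak memory.
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        return x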
@@ -1688,7 +1688,7 @@ class Molmo(nn.Module):
                     "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
                 )
         torch.backends.cuda.enable_flash_sdp(True)
-        torch.backends.cuda.enable_mem_efficient_sdp(False)  # this is super slow so make sure torch won't use it
+        torch.backends.cuda.enable_mem_efficient_sdp(True)  # this is super slow so make sure torch won't use it
         wte = None
         if self.config.additional_vocab_size is not None:
@@ -1741,6 +1741,8 @@ class Molmo(nn.Module):
         self.__num_fwd_flops: Optional[int] = None
 
+        self.gradient_checkpointing = False
+
     def reset_parameters(self):
         if self.vision_backbone is not None:
             self.vision_backbone.reset_parameters()
@@ -1951,7 +1953,11 @@ class Molmo(nn.Module):
                 all_hidden_states.append(x)
 
             layer_past = None if past_key_values is None else past_key_values[block_idx]
-            x, cache = block(x, attention_bias=attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)
+            if self.gradient_checkpointing and self.training:
+                x, cache = self._gradient_checkpointing_func(block, x, attention_bias=attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)
+            else:
+                x, cache = block(x, attention_bias=attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)
 
             if attn_key_values is not None:
                 assert cache is not None
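Here self._gradient_checkpointing_func is the hook that Hugging Face transformers installs on submodules when gradient checkpointing is enabled; in recent releases it is essentially a functools.partial over torch.utils.checkpoint.checkpoint. The sketch below shows that assumed equivalence only; the exact wrapper and its use_reentrant default are version-dependent, so treat this as an approximation rather than the library's literal implementation.

from functools import partial
import torch.utils.checkpoint

# Assumed shape of the callable transformers binds to `_gradient_checkpointing_func`
# when checkpointing is enabled (approximate; varies across transformers versions).
gradient_checkpointing_func = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)

# Under that assumption, the new training-time branch behaves roughly like:
#     x, cache = torch.utils.checkpoint.checkpoint(
#         block, x,
#         attention_bias=attention_bias,
#         position_ids=position_ids,
#         layer_past=layer_past,
#         use_cache=use_cache,
#         use_reentrant=False,
#     )
# i.e. the block's activations are recomputed during backward instead of kept in memory.
# Note that keyword arguments to `block` are only supported in the non-reentrant mode.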
@@ -2011,6 +2017,7 @@ class Molmo(nn.Module):
 class MolmoForCausalLM(PreTrainedModel):
     config_class = MolmoConfig
+    supports_gradient_checkpointing = True
     base_model_prefix = "model"
     _no_split_modules = ["MolmoBlock"]
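With supports_gradient_checkpointing = True on the PreTrainedModel wrapper, checkpointing can be toggled through the standard transformers API instead of flipping the flags by hand. A hedged usage sketch follows; the checkpoint name is only an example, and the gradient_checkpointing_kwargs argument requires a reasonably recent transformers release.

from transformers import AutoModelForCausalLM

# Example checkpoint name for illustration; substitute whatever Molmo weights you are training.
model = AutoModelForCausalLM.from_pretrained(
    "allenai/Molmo-7B-D-0924",
    trust_remote_code=True,
)
model.train()

# Walks submodules that define a `gradient_checkpointing` attribute (now including the
# patched Molmo module above) and sets the flag plus `_gradient_checkpointing_func`.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})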