Mirror of https://github.com/allenai/olmocr.git (synced 2025-08-18 13:52:17 +00:00)
Manually adding gradient checkpointing

Commit: f42bb02fce
Parent: 18569a4c63
@@ -1311,7 +1311,7 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
     def __init__(self, config: FullMolmoConfig):
         super().__init__(config)
         v_cfg = self.config.vision_backbone
-        self.grad_checkpointing = False
+        self.grad_checkpointing = True

         self.num_prefix_tokens = self.image_vit.num_prefix_tokens
         assert self.num_prefix_tokens in {0, 1}, "Only 0 or 1 prefix tokens are supported"
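This first hunk turns on activation (gradient) checkpointing for the vision backbone at construction time. For readers unfamiliar with the pattern, the sketch below (illustrative code, not taken from olmocr; the module and dimensions are invented) shows how such a flag typically gates torch.utils.checkpoint over a stack of blocks: checkpointed blocks discard intermediate activations in the forward pass and recompute them during backward, trading compute for memory.

    # Illustrative sketch only: a grad_checkpointing flag gating
    # torch.utils.checkpoint over a stack of blocks.
    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class TinyBackbone(nn.Module):
        def __init__(self, dim: int = 64, depth: int = 4):
            super().__init__()
            self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
            self.grad_checkpointing = True  # the kind of flag this commit flips

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            for block in self.blocks:
                if self.grad_checkpointing and self.training:
                    # Drop this block's activations now; recompute in backward.
                    x = checkpoint(block, x, use_reentrant=False)
                else:
                    x = block(x)
            return x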
@@ -1688,7 +1688,7 @@ class Molmo(nn.Module):
                 "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
             )
         torch.backends.cuda.enable_flash_sdp(True)
-        torch.backends.cuda.enable_mem_efficient_sdp(False)  # this is super slow so make sure torch won't use it
+        torch.backends.cuda.enable_mem_efficient_sdp(True)  # this is super slow so make sure torch won't use it

         wte = None
         if self.config.additional_vocab_size is not None:
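This hunk re-enables PyTorch's memory-efficient scaled-dot-product-attention backend; note the inline comment calling it "super slow" is left over from the removed line and now contradicts the new value. These torch.backends.cuda toggles are global switches that constrain which kernel F.scaled_dot_product_attention may dispatch to. A small sketch, with illustrative tensor shapes:

    # Sketch of PyTorch's global SDP backend toggles; shapes are illustrative.
    import torch
    import torch.nn.functional as F

    torch.backends.cuda.enable_flash_sdp(True)          # allow the FlashAttention kernel
    torch.backends.cuda.enable_mem_efficient_sdp(True)  # allow the memory-efficient kernel
    torch.backends.cuda.enable_math_sdp(True)           # keep the plain-math fallback

    if torch.cuda.is_available():
        q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
        k, v = torch.randn_like(q), torch.randn_like(q)
        # Dispatches to the fastest backend that is both enabled and
        # supported for these dtypes/shapes.
        out = F.scaled_dot_product_attention(q, k, v)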
@@ -1741,6 +1741,8 @@ class Molmo(nn.Module):

         self.__num_fwd_flops: Optional[int] = None

+        self.gradient_checkpointing = False
+
     def reset_parameters(self):
         if self.vision_backbone is not None:
             self.vision_backbone.reset_parameters()
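This hunk adds the gradient_checkpointing attribute that the forward pass (next hunk) checks, defaulting to off. Under the Hugging Face Transformers convention this mirrors, the flag is flipped at runtime and a checkpointing callable is installed alongside it. A minimal emulation of that convention (an assumption based on standard Transformers behavior, not code from this commit):

    # Minimal emulation of the HF-style enable step: walk submodules, set
    # the flag, and install the checkpoint wrapper the forward pass calls.
    from functools import partial

    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    def enable_gradient_checkpointing(model: nn.Module) -> None:
        func = partial(checkpoint, use_reentrant=False)
        for module in model.modules():
            if hasattr(module, "gradient_checkpointing"):
                module.gradient_checkpointing = True
                module._gradient_checkpointing_func = func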
@@ -1951,6 +1953,10 @@ class Molmo(nn.Module):
                 all_hidden_states.append(x)

             layer_past = None if past_key_values is None else past_key_values[block_idx]
-            x, cache = block(x, attention_bias=attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)
+
+            if self.gradient_checkpointing and self.training:
+                x, cache = self._gradient_checkpointing_func(block, x, attention_bias=attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)
+            else:
+                x, cache = block(x, attention_bias=attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)

             if attn_key_values is not None:
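The branch above only recomputes during training; at inference, or whenever the flag is off, the block runs normally. (Frameworks typically also force use_cache=False while checkpointing during training, since cached key/values interact poorly with recomputation.) Checkpointing is numerically transparent: outputs and gradients match the plain call, only activation memory and backward-pass compute differ. A self-contained sanity check on a toy block:

    # Toy check: checkpointed and plain calls agree on outputs and gradients.
    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    block = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))

    x1 = torch.randn(4, 16, requires_grad=True)
    x2 = x1.detach().clone().requires_grad_(True)

    plain = block(x1)
    ckpt = checkpoint(block, x2, use_reentrant=False)
    assert torch.allclose(plain, ckpt)

    (g1,) = torch.autograd.grad(plain.sum(), x1)
    (g2,) = torch.autograd.grad(ckpt.sum(), x2)
    assert torch.allclose(g1, g2)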
@@ -2011,6 +2017,7 @@ class Molmo(nn.Module):


 class MolmoForCausalLM(PreTrainedModel):
     config_class = MolmoConfig
+    supports_gradient_checkpointing = True
     base_model_prefix = "model"
     _no_split_modules = ["MolmoBlock"]
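With supports_gradient_checkpointing = True on the PreTrainedModel subclass, the standard Transformers enable call becomes available on loaded models (on recent Transformers versions). A hedged usage sketch; the checkpoint id and kwargs are illustrative, not prescribed by this commit:

    # Usage sketch: enabling checkpointing via the standard HF entry point.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "allenai/Molmo-7B-D-0924",  # illustrative checkpoint id
        trust_remote_code=True,
    )
    # Sets `gradient_checkpointing = True` on supporting submodules and
    # installs `_gradient_checkpointing_func`, which the forward pass calls.
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": False}
    )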
|
Loading…
x
Reference in New Issue
Block a user