diff --git a/olmocr/train/molmo/modeling_molmo.py b/olmocr/train/molmo/modeling_molmo.py
index caf60dd..e4d8460 100644
--- a/olmocr/train/molmo/modeling_molmo.py
+++ b/olmocr/train/molmo/modeling_molmo.py
@@ -1,3 +1,4 @@
+# type: ignore
 import logging
 import math
 from copy import deepcopy
@@ -576,26 +577,24 @@ class Dropout(nn.Dropout):
 
 @dataclass
 class VisionBackboneConfig:
-    def __init__(self):
-        super().__init__()
-        self.image_default_input_size: Tuple[int, int] = (336, 336)
-        self.image_patch_size: int = 14
-        self.image_pos_patch_size: int = 14
-        self.image_emb_dim: int = 1024
-        self.image_num_heads: int = 16
-        self.image_num_key_value_heads: int = 16
-        self.image_num_layers: int = 24
-        self.image_head_dim: int = 64
-        self.image_mlp_dim: int = 4096
-        self.image_mlp_activations: str = "gelu"
-        self.image_dropout_rate: float = 0.0
-        self.image_num_pos: int = 577
-        self.image_norm_eps: float = 1e-5
-        self.attention_dropout: float = 0.0
-        self.residual_dropout: float = 0.0
-        self.initializer_range: float = 0.02
-        self.fsdp_wrap: bool = False
-        self.resize_mode: str = "default"
+    image_default_input_size: Tuple[int, int] = (336, 336)
+    image_patch_size: int = 14
+    image_pos_patch_size: int = 14
+    image_emb_dim: int = 1024
+    image_num_heads: int = 16
+    image_num_key_value_heads: int = 16
+    image_num_layers: int = 24
+    image_head_dim: int = 64
+    image_mlp_dim: int = 4096
+    image_mlp_activations: str = "gelu"
+    image_dropout_rate: float = 0.0
+    image_num_pos: int = 577
+    image_norm_eps: float = 1e-5
+    attention_dropout: float = 0.0
+    residual_dropout: float = 0.0
+    initializer_range: float = 0.02
+    fsdp_wrap: bool = False
+    resize_mode: str = "default"
 
     def __post_init__(self):
         self.image_default_input_size = tuple(self.image_default_input_size)  # type: ignore[assignment]
@@ -608,61 +607,59 @@
 
 @dataclass
 class FullMolmoConfig:
-    def __init__(self):
-        super().__init__()
-        self.d_model: int = 768
-        self.n_heads: int = 12
-        self.n_kv_heads: Optional[int] = None
-        self.qkv_bias: bool = False
-        self.clip_qkv: Optional[float] = None
-        self.n_layers: int = 12
-        self.mlp_ratio: int = 4
-        self.mlp_hidden_size: Optional[int] = None
-        self.activation_type: str = "swiglu"
-        self.block_group_size: int = 1
-        self.rope: bool = True
-        self.rope_full_precision: bool = True
-        self.rope_theta: float = 10000.0
-        self.rope_impl: str = "interleave"
-        self.vision_backbone: Optional[VisionBackboneConfig] = None
-        self.attention_type: str = "sdpa"
-        self.float32_attention: bool = True
-        self.attention_dropout: float = 0.1
-        self.response_attention_dropout: float = 0.0
-        self.multi_query_attention: Optional[bool] = None
-        self.attention_layer_norm: bool = False
-        self.residual_dropout: float = 0.1
-        self.embedding_dropout: float = 0.1
-        self.layer_norm_type: str = "default"
-        self.layer_norm_with_affine: bool = True
-        self.layer_norm_eps: Optional[float] = None
-        self.attention_layer_norm_with_affine: bool = True
-        self.max_sequence_length: int = 1024
-        self.max_position_embeddings: Optional[int] = None
-        self.include_bias: bool = True
-        self.bias_for_layer_norm: Optional[bool] = None
-        self.scale_logits: bool = False
-        self.vocab_size: int = 50257
-        self.embedding_size: Optional[int] = 50304
-        self.additional_vocab_size: Optional[int] = None
-        self.new_embedding_init_range: float = 0.02
-        self.weight_tying: bool = True
-        self.pad_token_id: int = -1
-        self.init_device: Optional[str] = None
-        self.init_std: float = 0.02
-        self.init_cutoff_factor: Optional[float] = None
-        self.norm_after: bool = False
-        self.precision: Optional[str] = None
-        self.image_padding_embed: Optional[str] = None
-        self.vit_layers: Tuple = (-1,)
-        self.image_pooling_h: int = 2
-        self.image_pooling_w: int = 2
-        self.image_pooling_2d: str = "attention"
-        self.image_projector: str = "mlp"
-        self.image_feature_dropout: float = 0.0
-        self.initializer_range: float = 0.02
-        self.normalize_input_embeds: bool = False
-        self.use_position_ids: bool = True
+    d_model: int = 768
+    n_heads: int = 12
+    n_kv_heads: Optional[int] = None
+    qkv_bias: bool = False
+    clip_qkv: Optional[float] = None
+    n_layers: int = 12
+    mlp_ratio: int = 4
+    mlp_hidden_size: Optional[int] = None
+    activation_type: str = "swiglu"
+    block_group_size: int = 1
+    rope: bool = True
+    rope_full_precision: bool = True
+    rope_theta: float = 10000.0
+    rope_impl: str = "interleave"
+    vision_backbone: Optional[VisionBackboneConfig] = None
+    attention_type: str = "sdpa"
+    float32_attention: bool = True
+    attention_dropout: float = 0.1
+    response_attention_dropout: float = 0.0
+    multi_query_attention: Optional[bool] = None
+    attention_layer_norm: bool = False
+    residual_dropout: float = 0.1
+    embedding_dropout: float = 0.1
+    layer_norm_type: str = "default"
+    layer_norm_with_affine: bool = True
+    layer_norm_eps: Optional[float] = None
+    attention_layer_norm_with_affine: bool = True
+    max_sequence_length: int = 1024
+    max_position_embeddings: Optional[int] = None
+    include_bias: bool = True
+    bias_for_layer_norm: Optional[bool] = None
+    scale_logits: bool = False
+    vocab_size: int = 50257
+    embedding_size: Optional[int] = 50304
+    additional_vocab_size: Optional[int] = None
+    new_embedding_init_range: float = 0.02
+    weight_tying: bool = True
+    pad_token_id: int = -1
+    init_device: Optional[str] = None
+    init_std: float = 0.02
+    init_cutoff_factor: Optional[float] = None
+    norm_after: bool = False
+    precision: Optional[str] = None
+    image_padding_embed: Optional[str] = None
+    vit_layers: Tuple = (-1,)
+    image_pooling_h: int = 2
+    image_pooling_w: int = 2
+    image_pooling_2d: str = "attention"
+    image_projector: str = "mlp"
+    image_feature_dropout: float = 0.0
+    initializer_range: float = 0.02
+    normalize_input_embeds: bool = False
+    use_position_ids: bool = True
 
     @property
     def effective_n_kv_heads(self) -> int:
@@ -691,7 +688,7 @@ class FullMolmoConfig:
     @property
     def image_patch_size(self):
         assert self.vision_backbone is not None
-        return self.vision_backbone.image_patch_size  # type: ignore
+        return self.vision_backbone.image_patch_size
 
     def llm_patches_per_crop(self):
         h, w = self.image_num_patch
@@ -709,7 +706,7 @@ class ViTMLP(nn.Module):
     def __init__(self, config: FullMolmoConfig):
         super().__init__()
         self.config = config
-        v_cfg = config.vision_backbone or VisionBackboneConfig()
+        v_cfg = config.vision_backbone
 
         self.w1 = nn.Linear(
             v_cfg.image_emb_dim,
@@ -729,7 +726,7 @@
         )
 
     def reset_parameters(self):
-        v_cfg = self.config.vision_backbone or VisionBackboneConfig()
+        v_cfg = self.config.vision_backbone
         nn.init.trunc_normal_(self.w1.weight, std=math.sqrt(1 / v_cfg.image_emb_dim), a=-2.0, b=2.0)
         nn.init.trunc_normal_(self.w2.weight, std=math.sqrt(1 / v_cfg.image_mlp_dim), a=-2.0, b=2.0)
         nn.init.zeros_(self.w1.bias)
@@ -748,7 +745,7 @@ class ResidualAttentionBlock(nn.Module):
         super().__init__()
         self.config = config
 
-        v_cfg = config.vision_backbone or VisionBackboneConfig()
+        v_cfg = config.vision_backbone
         self.attention = MultiHeadDotProductAttention(config)
         self.feed_forward = ViTMLP(config)
         self.attention_norm = nn.LayerNorm(
@@ -781,7 +778,7 @@ class BlockCollection(nn.Module):
         self.config = config
         self.grad_checkpointing: bool = False
 
-        v_cfg = config.vision_backbone or VisionBackboneConfig()
+        v_cfg = config.vision_backbone
         self.resblocks = nn.ModuleList([ResidualAttentionBlock(config) for _ in range(v_cfg.image_num_layers)])
 
     def reset_parameters(self):
@@ -809,7 +806,7 @@ class VisionTransformer(nn.Module):
         super().__init__()
         self.config = config
 
-        v_cfg = config.vision_backbone or VisionBackboneConfig()
+        v_cfg = config.vision_backbone
         # class embeddings and positional embeddings
         self.scale = v_cfg.image_emb_dim**-0.5
         self.class_embedding = nn.Parameter(
@@ -852,15 +849,15 @@ class VisionTransformer(nn.Module):
 
         pos_emb = pos_emb.reshape((int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1]))
 
-        (patch_num_0, patch_num_1) = patch_num  # type: ignore
+        (patch_num_0, patch_num_1) = patch_num
 
-        if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:  # type: ignore
+        if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
             # Dervied from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
             # antialias: default True in jax.image.resize
             pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
             pos_emb = F.interpolate(
                 pos_emb,
-                size=(patch_num_0, patch_num_1),  # type: ignore
+                size=(patch_num_0, patch_num_1),
                 mode="bicubic",
                 align_corners=False,
                 antialias=True,
@@ -871,12 +868,12 @@ class VisionTransformer(nn.Module):
         x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype)
         return x
 
-    def forward(self, x: torch.Tensor, patch_num: Optional[int] = None) -> List[torch.Tensor]:
+    def forward(self, x: torch.Tensor, patch_num: int = None) -> List[torch.Tensor]:
         """
         : param x: (batch_size, num_patch, n_pixels)
         """
         if patch_num is None:
-            patch_num = self.config.vision_backbone.image_num_patch  # type: ignore
+            patch_num = self.config.vision_backbone.image_num_patch
         B, N, D = x.shape
 
         x = self.patch_embedding(x)
@@ -897,7 +894,7 @@ class MultiHeadDotProductAttention(nn.Module):
         self.config = config
         self.use_bias = use_bias
 
-        v_cfg = config.vision_backbone or VisionBackboneConfig()
+        v_cfg = config.vision_backbone
         self.embed_dim = v_cfg.image_emb_dim
         self.num_heads = v_cfg.image_num_heads
         self.head_dim = v_cfg.image_head_dim
@@ -989,12 +986,12 @@ class MultiHeadDotProductAttention(nn.Module):
         elif self.config.attention_type == "sdpa":
             if self.config.float32_attention and not torch.is_autocast_enabled():
                 xv = xv.to(torch.float32)
-            attn_output = F.scaled_dot_product_attention(  # type: ignore
+            attn_output = F.scaled_dot_product_attention(
                 xq.transpose(1, 2).contiguous(),
                 xk.transpose(1, 2).contiguous(),
                 xv.transpose(1, 2).contiguous(),
                 is_causal=False,
-                dropout_p=self.config.vision_backbone.attention_dropout,  # type: ignore
+                dropout_p=self.config.vision_backbone.attention_dropout,
             ).transpose(1, 2)
         else:
             raise NotImplementedError(self.config.attention_type)
@@ -1027,7 +1024,7 @@ class MultiHeadAttentionPool(nn.Module):
         self.mean_residual = mean_residual
         self.query = query
 
-        v_cfg = config.vision_backbone or VisionBackboneConfig()
+        v_cfg = config.vision_backbone
         input_dim = v_cfg.image_emb_dim
         self.embed_dim = v_cfg.image_emb_dim * factor
         self.num_heads = v_cfg.image_num_heads
@@ -1206,17 +1203,18 @@ class OLMoVisionBackbone(nn.Module):
         super().__init__()
         self.config = config
         self.image_vit = VisionTransformer(config)
-        input_dim: Optional[int] = None
+
+        input_dim: int = None
         self.image_pooling_2d: nn.Module = None
         if config.image_pooling_2d in {ImagePooling2DType.attention, ImagePooling2DType.attention_meanq}:
             self.image_pooling_2d = MultiHeadDotProductAttention(config, is_vit_layer=False)
-            input_dim = config.vision_backbone.image_emb_dim  # type: ignore
+            input_dim = config.vision_backbone.image_emb_dim
         elif config.image_pooling_2d == ImagePooling2DType.attention_2wide:
             cfg = deepcopy(config)
-            cfg.vision_backbone.image_emb_dim *= 2  # type: ignore
-            cfg.vision_backbone.image_head_dim *= 2  # type: ignore
+            cfg.vision_backbone.image_emb_dim *= 2
+            cfg.vision_backbone.image_head_dim *= 2
             self.image_pooling_2d = MultiHeadDotProductAttention(cfg, is_vit_layer=False)
-            input_dim = cfg.vision_backbone.image_emb_dim  # type: ignore
+            input_dim = cfg.vision_backbone.image_emb_dim
         elif config.image_pooling_2d == ImagePooling2DType.attention_v2:
             assert config.vit_layers is not None
             use_bias = True
@@ -1235,11 +1233,11 @@ class OLMoVisionBackbone(nn.Module):
                 query=query,
                 is_vit_layer=False,
             )
-            input_dim = config.vision_backbone.image_emb_dim * factor  # type: ignore
+            input_dim = config.vision_backbone.image_emb_dim * factor
         elif config.image_pooling_2d in [ImagePooling2DType.none, ImagePooling2DType.stack]:
             self.image_pooling_2d = None
             nlayers = 1 if config.vit_layers is None else len(config.vit_layers)
-            input_dim = nlayers * config.vision_backbone.image_emb_dim  # type: ignore
+            input_dim = nlayers * config.vision_backbone.image_emb_dim
         else:
             raise NotImplementedError(f"Unknown image pooling 2D method: {config.image_pooling_2d}")
@@ -1247,9 +1245,9 @@
 
         # `MLP` assume the activation takes two inputs, so it must be a 'llama' version
         if config.activation_type == ActivationType.swiglu:
-            mlp_config = replace(config, activation_type=ActivationType.llama_swiglu)  # type: ignore
+            mlp_config = replace(config, activation_type=ActivationType.llama_swiglu)
         elif config.activation_type == ActivationType.gelu:
-            mlp_config = replace(config, activation_type=ActivationType.llama_geglu)  # type: ignore
+            mlp_config = replace(config, activation_type=ActivationType.llama_geglu)
         else:
             mlp_config = config
         if config.image_projector == ImageProjectType.mlpx2:
@@ -1294,7 +1292,7 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
 
         self.pad_embed = None
         if config.image_padding_embed:
-            image_dim = v_cfg.image_emb_dim * len(self.config.vit_layers)  # type: ignore
+            image_dim = v_cfg.image_emb_dim * len(self.config.vit_layers)
            if config.image_padding_embed in ["pad_embed", "regress"]:
                self.pad_embed = nn.Parameter(torch.zeros((image_dim,), device=config.init_device))
            elif config.image_padding_embed == "pad_and_partial_pad":
@@ -1352,13 +1350,13 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
             assert image_masks is not None
             if cfg.image_padding_embed == "pad_embed":
                 all_pad = (image_masks == 0).to(dtype=torch.float32)
-                pad_embed = self.pad_embed[None, None, None, :]  # type: ignore
+                pad_embed = self.pad_embed[None, None, None, :]
                 image_features = image_features + pad_embed * torch.unsqueeze(all_pad, -1)
             elif cfg.image_padding_embed == "regress":
-                pad_embed = self.pad_embed[None, None, None, :]  # type: ignore
+                pad_embed = self.pad_embed[None, None, None, :]
                 image_features = image_features + pad_embed * torch.unsqueeze(torch.maximum(image_masks, torch.zeros_like(image_masks)), -1)
             elif cfg.image_padding_embed == "pad_and_partial_pad":
-                pad_embed = self.pad_embed[:, None, None, None, :]  # type: ignore
+                pad_embed = self.pad_embed[:, None, None, None, :]
                 all_pad = image_masks == 0
                 partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(dtype=image_features.dtype)
                 all_pad = all_pad.to(dtype=image_features.dtype)
@@ -1560,12 +1558,12 @@ class LayerNormBase(nn.Module):
         self.eps = self.config.layer_norm_eps or eps
         self.normalized_shape = (size or config.d_model,)
         if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
-            self.weight = nn.Parameter(weight_initializer(self.normalized_shape, device=config.init_device))  # type: ignore
+            self.weight = nn.Parameter(weight_initializer(self.normalized_shape, device=config.init_device))
             use_bias = self.config.bias_for_layer_norm
             if use_bias is None:
                 use_bias = self.config.include_bias
             if use_bias:
-                self.bias = nn.Parameter(bias_initializer(self.normalized_shape, device=config.init_device))  # type: ignore
+                self.bias = nn.Parameter(bias_initializer(self.normalized_shape, device=config.init_device))
             else:
                 self.register_parameter("bias", None)
         else:
@@ -1596,7 +1594,7 @@ class RMSLayerNorm(LayerNormBase):
         elementwise_affine: Optional[bool] = None,
         eps: float = 1e-5,
     ):
-        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)  # type: ignore
+        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         with torch.autocast(enabled=False, device_type=x.device.type):
@@ -1628,7 +1626,7 @@ class LayerNorm(LayerNormBase):
         elementwise_affine: Optional[bool] = None,
         eps: float = 1e-05,
     ):
-        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)  # type: ignore
+        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
         self.low_precision = low_precision
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1666,7 +1664,7 @@ class Molmo(nn.Module):
         if self.config.additional_vocab_size is not None:
             wte = Embedding(
                 config.embedding_size or config.vocab_size,
-                config.additional_vocab_size,  # type: ignore
+                config.additional_vocab_size,
                 config.d_model,
                 device=config.init_device,
                 initializer_range=config.initializer_range,
@@ -1683,7 +1681,7 @@ class Molmo(nn.Module):
                 )
             )
 
-        blocks = [MolmoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]  # type: ignore
+        blocks = [MolmoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]
         if self.config.block_group_size > 1:
             raise NotImplementedError()
         else:
@@ -1807,14 +1805,14 @@ class Molmo(nn.Module):
         if self.config.use_position_ids and attention_mask is None:
             attention_mask = input_ids != -1
 
-        if subsegment_ids is not None and attention_mask is not None:
+        if subsegment_ids is not None:
             assert not use_cache, "Subsegment_ids cannot be used with cache."
             subsegment_mask = subsegment_ids.unsqueeze(2) <= subsegment_ids.unsqueeze(1)
             attention_mask = subsegment_mask.to(attention_mask.dtype) * attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1)
             if position_ids is None:
                 raise ValueError("Positioned ids must be given if using subsegment_ids")
         else:
-            if self.config.use_position_ids and position_ids is None and attention_mask is not None:
+            if self.config.use_position_ids and position_ids is None:
                 position_ids = torch.clamp(
                     torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1,
                     min=0,
@@ -1827,10 +1825,10 @@ class Molmo(nn.Module):
 
         x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings  # type: ignore
 
         num_image: Optional[int] = None
-        if images is not None and image_input_idx is not None:
+        if images is not None:
             # shape: (batch_size, num_image, num_patch, d_model)
             # cls_embed: (batch_size, num_image, d_model)
-            image_features, cls_embed = self.vision_backbone(images, image_masks)  # type: ignore
+            image_features, cls_embed = self.vision_backbone(images, image_masks)
             num_image, num_patch = image_features.shape[1:3]
             assert image_input_idx.shape == (batch_size, num_image, num_patch)
@@ -2011,8 +2009,8 @@ class MolmoForCausalLM(PreTrainedModel):
             rope_theta=config.rope_theta,
             layer_norm_eps=config.layer_norm_eps,
             layer_norm_type=config.layer_norm_type,
-            vit_layers=[-2, -9],  # type: ignore
-            vision_backbone=VisionBackboneConfig(  # type: ignore
+            vit_layers=[-2, -9],
+            vision_backbone=VisionBackboneConfig(
                 image_default_input_size=(336, 336),
                 image_patch_size=14,
                 image_pos_patch_size=14,
@@ -2056,7 +2054,7 @@ class MolmoForCausalLM(PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         append_last_valid_logits: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
-        cache_position: Optional[  # type: ignore
+        cache_position: Optional[
             Cache
         ] = None,  # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
     ) -> Union[Tuple, CausalLMOutputWithPast]:
@@ -2082,7 +2080,7 @@ class MolmoForCausalLM(PreTrainedModel):
             position_ids=position_ids,
             past_key_values=past_key_values,
             use_cache=use_cache,
-            last_logits_only=last_logits_only,  # type: ignore
+            last_logits_only=last_logits_only,
             output_hidden_states=output_hidden_states,
             append_last_valid_logits=append_last_valid_logits,
         )
@@ -2156,7 +2154,7 @@ class MolmoForCausalLM(PreTrainedModel):
         input_ids = batch["input_ids"]
         batch_size, seq_len = input_ids.shape
         attention_mask = batch.get("attention_mask", None)
-        max_new_tokens = generation_config.max_new_tokens  # type: ignore
+        max_new_tokens = generation_config.max_new_tokens
         assert max_new_tokens is not None
         mask_len = seq_len + max_new_tokens if self.config.use_position_ids else seq_len
         position_ids: Optional[torch.Tensor] = None