restored modeling_molmo.py file

parent 4bff92053b
commit f57c6f3f7b
@@ -1,3 +1,4 @@
+# type: ignore
 import logging
 import math
 from copy import deepcopy
@@ -576,26 +577,24 @@ class Dropout(nn.Dropout):
 
 @dataclass
 class VisionBackboneConfig:
-    def __init__(self):
-        super().__init__()
-        self.image_default_input_size: Tuple[int, int] = (336, 336)
-        self.image_patch_size: int = 14
-        self.image_pos_patch_size: int = 14
-        self.image_emb_dim: int = 1024
-        self.image_num_heads: int = 16
-        self.image_num_key_value_heads: int = 16
-        self.image_num_layers: int = 24
-        self.image_head_dim: int = 64
-        self.image_mlp_dim: int = 4096
-        self.image_mlp_activations: str = "gelu"
-        self.image_dropout_rate: float = 0.0
-        self.image_num_pos: int = 577
-        self.image_norm_eps: float = 1e-5
-        self.attention_dropout: float = 0.0
-        self.residual_dropout: float = 0.0
-        self.initializer_range: float = 0.02
-        self.fsdp_wrap: bool = False
-        self.resize_mode: str = "default"
+    image_default_input_size: Tuple[int, int] = (336, 336)
+    image_patch_size: int = 14
+    image_pos_patch_size: int = 14
+    image_emb_dim: int = 1024
+    image_num_heads: int = 16
+    image_num_key_value_heads: int = 16
+    image_num_layers: int = 24
+    image_head_dim: int = 64
+    image_mlp_dim: int = 4096
+    image_mlp_activations: str = "gelu"
+    image_dropout_rate: float = 0.0
+    image_num_pos: int = 577
+    image_norm_eps: float = 1e-5
+    attention_dropout: float = 0.0
+    residual_dropout: float = 0.0
+    initializer_range: float = 0.02
+    fsdp_wrap: bool = False
+    resize_mode: str = "default"
 
     def __post_init__(self):
         self.image_default_input_size = tuple(self.image_default_input_size)  # type: ignore[assignment]
@@ -608,61 +607,59 @@ class FullMolmoConfig:
 
 @dataclass
 class FullMolmoConfig:
-    def __init__(self):
-        super().__init__()
-        self.d_model: int = 768
-        self.n_heads: int = 12
-        self.n_kv_heads: Optional[int] = None
-        self.qkv_bias: bool = False
-        self.clip_qkv: Optional[float] = None
-        self.n_layers: int = 12
-        self.mlp_ratio: int = 4
-        self.mlp_hidden_size: Optional[int] = None
-        self.activation_type: str = "swiglu"
-        self.block_group_size: int = 1
-        self.rope: bool = True
-        self.rope_full_precision: bool = True
-        self.rope_theta: float = 10000.0
-        self.rope_impl: str = "interleave"
-        self.vision_backbone: Optional[VisionBackboneConfig] = None
-        self.attention_type: str = "sdpa"
-        self.float32_attention: bool = True
-        self.attention_dropout: float = 0.1
-        self.response_attention_dropout: float = 0.0
-        self.multi_query_attention: Optional[bool] = None
-        self.attention_layer_norm: bool = False
-        self.residual_dropout: float = 0.1
-        self.embedding_dropout: float = 0.1
-        self.layer_norm_type: str = "default"
-        self.layer_norm_with_affine: bool = True
-        self.layer_norm_eps: Optional[float] = None
-        self.attention_layer_norm_with_affine: bool = True
-        self.max_sequence_length: int = 1024
-        self.max_position_embeddings: Optional[int] = None
-        self.include_bias: bool = True
-        self.bias_for_layer_norm: Optional[bool] = None
-        self.scale_logits: bool = False
-        self.vocab_size: int = 50257
-        self.embedding_size: Optional[int] = 50304
-        self.additional_vocab_size: Optional[int] = None
-        self.new_embedding_init_range: float = 0.02
-        self.weight_tying: bool = True
-        self.pad_token_id: int = -1
-        self.init_device: Optional[str] = None
-        self.init_std: float = 0.02
-        self.init_cutoff_factor: Optional[float] = None
-        self.norm_after: bool = False
-        self.precision: Optional[str] = None
-        self.image_padding_embed: Optional[str] = None
-        self.vit_layers: Tuple = (-1,)
-        self.image_pooling_h: int = 2
-        self.image_pooling_w: int = 2
-        self.image_pooling_2d: str = "attention"
-        self.image_projector: str = "mlp"
-        self.image_feature_dropout: float = 0.0
-        self.initializer_range: float = 0.02
-        self.normalize_input_embeds: bool = False
-        self.use_position_ids: bool = True
+    d_model: int = 768
+    n_heads: int = 12
+    n_kv_heads: Optional[int] = None
+    qkv_bias: bool = False
+    clip_qkv: Optional[float] = None
+    n_layers: int = 12
+    mlp_ratio: int = 4
+    mlp_hidden_size: Optional[int] = None
+    activation_type: str = "swiglu"
+    block_group_size: int = 1
+    rope: bool = True
+    rope_full_precision: bool = True
+    rope_theta: float = 10000.0
+    rope_impl: str = "interleave"
+    vision_backbone: Optional[VisionBackboneConfig] = None
+    attention_type: str = "sdpa"
+    float32_attention: bool = True
+    attention_dropout: float = 0.1
+    response_attention_dropout: float = 0.0
+    multi_query_attention: Optional[bool] = None
+    attention_layer_norm: bool = False
+    residual_dropout: float = 0.1
+    embedding_dropout: float = 0.1
+    layer_norm_type: str = "default"
+    layer_norm_with_affine: bool = True
+    layer_norm_eps: Optional[float] = None
+    attention_layer_norm_with_affine: bool = True
+    max_sequence_length: int = 1024
+    max_position_embeddings: Optional[int] = None
+    include_bias: bool = True
+    bias_for_layer_norm: Optional[bool] = None
+    scale_logits: bool = False
+    vocab_size: int = 50257
+    embedding_size: Optional[int] = 50304
+    additional_vocab_size: Optional[int] = None
+    new_embedding_init_range: float = 0.02
+    weight_tying: bool = True
+    pad_token_id: int = -1
+    init_device: Optional[str] = None
+    init_std: float = 0.02
+    init_cutoff_factor: Optional[float] = None
+    norm_after: bool = False
+    precision: Optional[str] = None
+    image_padding_embed: Optional[str] = None
+    vit_layers: Tuple = (-1,)
+    image_pooling_h: int = 2
+    image_pooling_w: int = 2
+    image_pooling_2d: str = "attention"
+    image_projector: str = "mlp"
+    image_feature_dropout: float = 0.0
+    initializer_range: float = 0.02
+    normalize_input_embeds: bool = False
+    use_position_ids: bool = True
 
     @property
     def effective_n_kv_heads(self) -> int:
@@ -691,7 +688,7 @@ class FullMolmoConfig:
     @property
     def image_patch_size(self):
         assert self.vision_backbone is not None
-        return self.vision_backbone.image_patch_size
+        return self.visoin_backbone.image_patch_size
 
     def llm_patches_per_crop(self):
         h, w = self.image_num_patch
@@ -709,7 +706,7 @@ class ViTMLP(nn.Module):
     def __init__(self, config: FullMolmoConfig):
         super().__init__()
         self.config = config
-        v_cfg = config.vision_backbone or VisionBackboneConfig()
+        v_cfg = config.vision_backbone
 
         self.w1 = nn.Linear(
             v_cfg.image_emb_dim,
@@ -729,7 +726,7 @@ class ViTMLP(nn.Module):
         )
 
     def reset_parameters(self):
-        v_cfg = self.config.vision_backbone or VisionBackboneConfig()
+        v_cfg = self.config.vision_backbone
         nn.init.trunc_normal_(self.w1.weight, std=math.sqrt(1 / v_cfg.image_emb_dim), a=-2.0, b=2.0)
         nn.init.trunc_normal_(self.w2.weight, std=math.sqrt(1 / v_cfg.image_mlp_dim), a=-2.0, b=2.0)
         nn.init.zeros_(self.w1.bias)
@@ -748,7 +745,7 @@ class ResidualAttentionBlock(nn.Module):
         super().__init__()
         self.config = config
 
-        v_cfg = config.vision_backbone or VisionBackboneConfig()
+        v_cfg = config.vision_backbone
         self.attention = MultiHeadDotProductAttention(config)
         self.feed_forward = ViTMLP(config)
         self.attention_norm = nn.LayerNorm(
@@ -781,7 +778,7 @@ class BlockCollection(nn.Module):
         self.config = config
         self.grad_checkpointing: bool = False
 
-        v_cfg = config.vision_backbone or VisionBackboneConfig()
+        v_cfg = config.vision_backbone
         self.resblocks = nn.ModuleList([ResidualAttentionBlock(config) for _ in range(v_cfg.image_num_layers)])
 
     def reset_parameters(self):
@@ -809,7 +806,7 @@ class VisionTransformer(nn.Module):
         super().__init__()
         self.config = config
 
-        v_cfg = config.vision_backbone or VisionBackboneConfig()
+        v_cfg = config.vision_backbone
         # class embeddings and positional embeddings
         self.scale = v_cfg.image_emb_dim**-0.5
         self.class_embedding = nn.Parameter(
@@ -852,15 +849,15 @@ class VisionTransformer(nn.Module):
 
         pos_emb = pos_emb.reshape((int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1]))
 
-        (patch_num_0, patch_num_1) = patch_num  # type: ignore
+        (patch_num_0, patch_num_1) = patch_num
 
-        if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:  # type: ignore
+        if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
             # Dervied from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
             # antialias: default True in jax.image.resize
             pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
             pos_emb = F.interpolate(
                 pos_emb,
-                size=(patch_num_0, patch_num_1),  # type: ignore
+                size=(patch_num_0, patch_num_1),
                 mode="bicubic",
                 align_corners=False,
                 antialias=True,
@@ -871,12 +868,12 @@ class VisionTransformer(nn.Module):
         x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype)
         return x
 
-    def forward(self, x: torch.Tensor, patch_num: Optional[int] = None) -> List[torch.Tensor]:
+    def forward(self, x: torch.Tensor, patch_num: int = None) -> List[torch.Tensor]:
         """
         : param x: (batch_size, num_patch, n_pixels)
         """
         if patch_num is None:
-            patch_num = self.config.vision_backbone.image_num_patch  # type: ignore
+            patch_num = self.config.vision_backbone.image_num_patch
         B, N, D = x.shape
 
         x = self.patch_embedding(x)
@@ -897,7 +894,7 @@ class MultiHeadDotProductAttention(nn.Module):
         self.config = config
         self.use_bias = use_bias
 
-        v_cfg = config.vision_backbone or VisionBackboneConfig()
+        v_cfg = config.vision_backbone
         self.embed_dim = v_cfg.image_emb_dim
         self.num_heads = v_cfg.image_num_heads
         self.head_dim = v_cfg.image_head_dim
@@ -989,12 +986,12 @@ class MultiHeadDotProductAttention(nn.Module):
         elif self.config.attention_type == "sdpa":
             if self.config.float32_attention and not torch.is_autocast_enabled():
                 xv = xv.to(torch.float32)
-            attn_output = F.scaled_dot_product_attention(  # type: ignore
+            attn_output = F.scaled_dot_product_attention(
                 xq.transpose(1, 2).contiguous(),
                 xk.transpose(1, 2).contiguous(),
                 xv.transpose(1, 2).contiguous(),
                 is_causal=False,
-                dropout_p=self.config.vision_backbone.attention_dropout,  # type: ignore
+                dropout_p=self.config.vision_backbone.attention_dropout,
             ).transpose(1, 2)
         else:
             raise NotImplementedError(self.config.attention_type)
@@ -1027,7 +1024,7 @@ class MultiHeadAttentionPool(nn.Module):
         self.mean_residual = mean_residual
         self.query = query
 
-        v_cfg = config.vision_backbone or VisionBackboneConfig()
+        v_cfg = config.vision_backbone
         input_dim = v_cfg.image_emb_dim
         self.embed_dim = v_cfg.image_emb_dim * factor
         self.num_heads = v_cfg.image_num_heads
@@ -1206,17 +1203,18 @@ class OLMoVisionBackbone(nn.Module):
         super().__init__()
         self.config = config
         self.image_vit = VisionTransformer(config)
-        input_dim: Optional[int] = None
+
+        input_dim: int = None
         self.image_pooling_2d: nn.Module = None
         if config.image_pooling_2d in {ImagePooling2DType.attention, ImagePooling2DType.attention_meanq}:
             self.image_pooling_2d = MultiHeadDotProductAttention(config, is_vit_layer=False)
-            input_dim = config.vision_backbone.image_emb_dim  # type: ignore
+            input_dim = config.vision_backbone.image_emb_dim
         elif config.image_pooling_2d == ImagePooling2DType.attention_2wide:
             cfg = deepcopy(config)
-            cfg.vision_backbone.image_emb_dim *= 2  # type: ignore
-            cfg.vision_backbone.image_head_dim *= 2  # type: ignore
+            cfg.vision_backbone.image_emb_dim *= 2
+            cfg.vision_backbone.image_head_dim *= 2
             self.image_pooling_2d = MultiHeadDotProductAttention(cfg, is_vit_layer=False)
-            input_dim = cfg.vision_backbone.image_emb_dim  # type: ignore
+            input_dim = cfg.vision_backbone.image_emb_dim
         elif config.image_pooling_2d == ImagePooling2DType.attention_v2:
             assert config.vit_layers is not None
             use_bias = True
@@ -1235,11 +1233,11 @@ class OLMoVisionBackbone(nn.Module):
                 query=query,
                 is_vit_layer=False,
             )
-            input_dim = config.vision_backbone.image_emb_dim * factor  # type: ignore
+            input_dim = config.vision_backbone.image_emb_dim * factor
         elif config.image_pooling_2d in [ImagePooling2DType.none, ImagePooling2DType.stack]:
             self.image_pooling_2d = None
             nlayers = 1 if config.vit_layers is None else len(config.vit_layers)
-            input_dim = nlayers * config.vision_backbone.image_emb_dim  # type: ignore
+            input_dim = nlayers * config.vision_backbone.image_emb_dim
         else:
             raise NotImplementedError(f"Unknown image pooling 2D method: {config.image_pooling_2d}")
 
@@ -1247,9 +1245,9 @@ class OLMoVisionBackbone(nn.Module):
 
         # `MLP` assume the activation takes two inputs, so it must be a 'llama' version
         if config.activation_type == ActivationType.swiglu:
-            mlp_config = replace(config, activation_type=ActivationType.llama_swiglu)  # type: ignore
+            mlp_config = replace(config, activation_type=ActivationType.llama_swiglu)
        elif config.activation_type == ActivationType.gelu:
-            mlp_config = replace(config, activation_type=ActivationType.llama_geglu)  # type: ignore
+            mlp_config = replace(config, activation_type=ActivationType.llama_geglu)
         else:
             mlp_config = config
         if config.image_projector == ImageProjectType.mlpx2:
@@ -1294,7 +1292,7 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
 
         self.pad_embed = None
         if config.image_padding_embed:
-            image_dim = v_cfg.image_emb_dim * len(self.config.vit_layers)  # type: ignore
+            image_dim = v_cfg.image_emb_dim * len(self.config.vit_layers)
             if config.image_padding_embed in ["pad_embed", "regress"]:
                 self.pad_embed = nn.Parameter(torch.zeros((image_dim,), device=config.init_device))
             elif config.image_padding_embed == "pad_and_partial_pad":
@@ -1352,13 +1350,13 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
             assert image_masks is not None
             if cfg.image_padding_embed == "pad_embed":
                 all_pad = (image_masks == 0).to(dtype=torch.float32)
-                pad_embed = self.pad_embed[None, None, None, :]  # type: ignore
+                pad_embed = self.pad_embed[None, None, None, :]
                 image_features = image_features + pad_embed * torch.unsqueeze(all_pad, -1)
             elif cfg.image_padding_embed == "regress":
-                pad_embed = self.pad_embed[None, None, None, :]  # type: ignore
+                pad_embed = self.pad_embed[None, None, None, :]
                 image_features = image_features + pad_embed * torch.unsqueeze(torch.maximum(image_masks, torch.zeros_like(image_masks)), -1)
             elif cfg.image_padding_embed == "pad_and_partial_pad":
-                pad_embed = self.pad_embed[:, None, None, None, :]  # type: ignore
+                pad_embed = self.pad_embed[:, None, None, None, :]
                 all_pad = image_masks == 0
                 partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(dtype=image_features.dtype)
                 all_pad = all_pad.to(dtype=image_features.dtype)
@@ -1560,12 +1558,12 @@ class LayerNormBase(nn.Module):
         self.eps = self.config.layer_norm_eps or eps
         self.normalized_shape = (size or config.d_model,)
         if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
-            self.weight = nn.Parameter(weight_initializer(self.normalized_shape, device=config.init_device))  # type: ignore
+            self.weight = nn.Parameter(weight_initializer(self.normalized_shape, device=config.init_device))
             use_bias = self.config.bias_for_layer_norm
             if use_bias is None:
                 use_bias = self.config.include_bias
             if use_bias:
-                self.bias = nn.Parameter(bias_initializer(self.normalized_shape, device=config.init_device))  # type: ignore
+                self.bias = nn.Parameter(bias_initializer(self.normalized_shape, device=config.init_device))
             else:
                 self.register_parameter("bias", None)
         else:
@@ -1596,7 +1594,7 @@ class RMSLayerNorm(LayerNormBase):
         elementwise_affine: Optional[bool] = None,
         eps: float = 1e-5,
     ):
-        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)  # type: ignore
+        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         with torch.autocast(enabled=False, device_type=x.device.type):
@@ -1628,7 +1626,7 @@ class LayerNorm(LayerNormBase):
         elementwise_affine: Optional[bool] = None,
         eps: float = 1e-05,
     ):
-        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)  # type: ignore
+        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
         self.low_precision = low_precision
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1666,7 +1664,7 @@ class Molmo(nn.Module):
         if self.config.additional_vocab_size is not None:
             wte = Embedding(
                 config.embedding_size or config.vocab_size,
-                config.additional_vocab_size,  # type: ignore
+                config.additional_vocab_size,
                 config.d_model,
                 device=config.init_device,
                 initializer_range=config.initializer_range,
@@ -1683,7 +1681,7 @@ class Molmo(nn.Module):
             )
         )
 
-        blocks = [MolmoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]  # type: ignore
+        blocks = [MolmoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]
         if self.config.block_group_size > 1:
             raise NotImplementedError()
         else:
@@ -1807,14 +1805,14 @@ class Molmo(nn.Module):
         if self.config.use_position_ids and attention_mask is None:
             attention_mask = input_ids != -1
 
-        if subsegment_ids is not None and attention_mask is not None:
+        if subsegment_ids is not None:
             assert not use_cache, "Subsegment_ids cannot be used with cache."
             subsegment_mask = subsegment_ids.unsqueeze(2) <= subsegment_ids.unsqueeze(1)
             attention_mask = subsegment_mask.to(attention_mask.dtype) * attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1)
             if position_ids is None:
                 raise ValueError("Positioned ids must be given if using subsegment_ids")
         else:
-            if self.config.use_position_ids and position_ids is None and attention_mask is not None:
+            if self.config.use_position_ids and position_ids is None:
                 position_ids = torch.clamp(
                     torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1,
                     min=0,
@@ -1827,10 +1825,10 @@ class Molmo(nn.Module):
         x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings  # type: ignore
 
         num_image: Optional[int] = None
-        if images is not None and image_input_idx is not None:
+        if images is not None:
             # shape: (batch_size, num_image, num_patch, d_model)
             # cls_embed: (batch_size, num_image, d_model)
-            image_features, cls_embed = self.vision_backbone(images, image_masks)  # type: ignore
+            image_features, cls_embed = self.vision_backbone(images, image_masks)
             num_image, num_patch = image_features.shape[1:3]
             assert image_input_idx.shape == (batch_size, num_image, num_patch)
 
@@ -2011,8 +2009,8 @@ class MolmoForCausalLM(PreTrainedModel):
             rope_theta=config.rope_theta,
             layer_norm_eps=config.layer_norm_eps,
             layer_norm_type=config.layer_norm_type,
-            vit_layers=[-2, -9],  # type: ignore
-            vision_backbone=VisionBackboneConfig(  # type: ignore
+            vit_layers=[-2, -9],
+            vision_backbone=VisionBackboneConfig(
                 image_default_input_size=(336, 336),
                 image_patch_size=14,
                 image_pos_patch_size=14,
@@ -2056,7 +2054,7 @@ class MolmoForCausalLM(PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         append_last_valid_logits: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
-        cache_position: Optional[  # type: ignore
+        cache_position: Optional[
            Cache
         ] = None,  # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
     ) -> Union[Tuple, CausalLMOutputWithPast]:
@@ -2082,7 +2080,7 @@ class MolmoForCausalLM(PreTrainedModel):
             position_ids=position_ids,
             past_key_values=past_key_values,
             use_cache=use_cache,
-            last_logits_only=last_logits_only,  # type: ignore
+            last_logits_only=last_logits_only,
             output_hidden_states=output_hidden_states,
             append_last_valid_logits=append_last_valid_logits,
         )
@@ -2156,7 +2154,7 @@ class MolmoForCausalLM(PreTrainedModel):
         input_ids = batch["input_ids"]
         batch_size, seq_len = input_ids.shape
         attention_mask = batch.get("attention_mask", None)
-        max_new_tokens = generation_config.max_new_tokens  # type: ignore
+        max_new_tokens = generation_config.max_new_tokens
         assert max_new_tokens is not None
         mask_len = seq_len + max_new_tokens if self.config.use_position_ids else seq_len
         position_ids: Optional[torch.Tensor] = None