mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-29 02:50:15 +00:00
new experiment w/o causal mask
This commit is contained in:
parent
00a466f0b9
commit
bdea15f6c6
@ -23,6 +23,7 @@ For example,
|
||||
| 10 | gpt2-small (124M) | pretrained | last | last_block | context length (1024) | 83.08% | 87.92% | 78.33% | 2.46 min | A100 |
|
||||
| 11 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 1) | 100.00% | 98.66% | 98.00% | 1.75 min | A100 |
|
||||
| 12 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 8) | 99.33% | 98.66% | 98.33% | 1.70 min | A100 |
|
||||
| 13 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120); but no causal mask | 99.23% | 98.66% | 95.33% | 0.29 min | A100 |
|
||||
|
||||
|
||||
|
||||
@ -43,6 +44,7 @@ You can use the following code to reproduce the experiments:
|
||||
- Row 10: `python additional-experiments.py --context_length "model_context_length"`
|
||||
- Row 11: `python additional-experiments.py --no_padding --batch_size 1`
|
||||
- Row 12: `python additional-experiments.py --no_padding --batch_size 1 --accumulation_steps 8`
|
||||
- Row 13: `python additional-experiments.py --disable_causal_mask`
|
||||
|
||||
I've kept the LLM and dataset small on purpose, so you can run the training on a regular laptop like a MacBook Air M3 in about 15 minutes in case you don't have access to a GPU.
|
||||
|
||||
@ -65,3 +67,5 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a
|
||||
7. **Padding Input to Full Context Length vs. Longest Training Example (Row 1 vs. 10)**: Padding the input to the full supported context length results is significantly worse.
|
||||
|
||||
8. **Padding vs no padding (Row 1 vs. 11 and 12)**: The `--no_padding` option disables the padding in the dataset, which requires training the model with a batch size of 1 since the inputs have variable lengths. This results in a better test accuracy but takes longer to train. In row 12, we additionally enable gradient accumulation with 8 steps to achieve the same batch size as in the other experiments, which helps reduce overfitting and slightly boost the test set accuracy.
|
||||
|
||||
9. **Disabling the causal attention mask (Row 1 vs. 13)**: Disables the causal attention mask used in the multi-head attention module. This means all tokens can attend all other tokens. The model accuracy is slightly improved compared to the GPT model with causal mask.
|
||||
|
@ -153,7 +153,7 @@ def instantiate_model(choose_model, load_weights):
|
||||
|
||||
if not load_weights:
|
||||
torch.manual_seed(123)
|
||||
model = GPTModel(BASE_CONFIG)
|
||||
model = GPTModel(BASE_CONFIG, disable_causal_mask=args.disable_causal_mask)
|
||||
|
||||
if load_weights:
|
||||
model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
|
||||
@ -386,6 +386,15 @@ if __name__ == "__main__":
|
||||
)
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable_causal_mask",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help=(
|
||||
"Disables the causal attention mask."
|
||||
)
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.trainable_token == "first":
|
||||
|
@ -60,7 +60,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||
# Chapter 3
|
||||
#####################################
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
|
||||
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False, disable_causal_mask=False):
|
||||
super().__init__()
|
||||
assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
|
||||
|
||||
@ -73,7 +73,10 @@ class MultiHeadAttention(nn.Module):
|
||||
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
|
||||
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
|
||||
|
||||
if not disable_causal_mask:
|
||||
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
|
||||
self.disable_causal_mask = disable_causal_mask
|
||||
|
||||
def forward(self, x):
|
||||
b, num_tokens, d_in = x.shape
|
||||
@ -96,11 +99,12 @@ class MultiHeadAttention(nn.Module):
|
||||
# Compute scaled dot-product attention (aka self-attention) with a causal mask
|
||||
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
||||
|
||||
# Original mask truncated to the number of tokens and converted to boolean
|
||||
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
|
||||
if not self.disable_causal_mask:
|
||||
# Original mask truncated to the number of tokens and converted to boolean
|
||||
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
|
||||
|
||||
# Use the mask to fill attention scores
|
||||
attn_scores.masked_fill_(mask_bool, -torch.inf)
|
||||
# Use the mask to fill attention scores
|
||||
attn_scores.masked_fill_(mask_bool, -torch.inf)
|
||||
|
||||
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
||||
attn_weights = self.dropout(attn_weights)
|
||||
@ -157,7 +161,7 @@ class FeedForward(nn.Module):
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
def __init__(self, cfg, disable_causal_mask=False):
|
||||
super().__init__()
|
||||
self.att = MultiHeadAttention(
|
||||
d_in=cfg["emb_dim"],
|
||||
@ -165,7 +169,9 @@ class TransformerBlock(nn.Module):
|
||||
context_length=cfg["context_length"],
|
||||
num_heads=cfg["n_heads"],
|
||||
dropout=cfg["drop_rate"],
|
||||
qkv_bias=cfg["qkv_bias"])
|
||||
qkv_bias=cfg["qkv_bias"],
|
||||
disable_causal_mask=disable_causal_mask
|
||||
)
|
||||
self.ff = FeedForward(cfg)
|
||||
self.norm1 = LayerNorm(cfg["emb_dim"])
|
||||
self.norm2 = LayerNorm(cfg["emb_dim"])
|
||||
@ -190,14 +196,14 @@ class TransformerBlock(nn.Module):
|
||||
|
||||
|
||||
class GPTModel(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
def __init__(self, cfg, disable_causal_mask=False):
|
||||
super().__init__()
|
||||
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
|
||||
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
|
||||
self.drop_emb = nn.Dropout(cfg["drop_rate"])
|
||||
|
||||
self.trf_blocks = nn.Sequential(
|
||||
*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
|
||||
*[TransformerBlock(cfg, disable_causal_mask) for _ in range(cfg["n_layers"])])
|
||||
|
||||
self.final_norm = LayerNorm(cfg["emb_dim"])
|
||||
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user