new experiment w/o causal mask

This commit is contained in:
rasbt 2024-05-18 17:03:36 -05:00
parent 00a466f0b9
commit bdea15f6c6
3 changed files with 30 additions and 11 deletions

View File

@ -23,6 +23,7 @@ For example,
| 10 | gpt2-small (124M) | pretrained | last | last_block | context length (1024) | 83.08% | 87.92% | 78.33% | 2.46 min | A100 |
| 11 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 1) | 100.00% | 98.66% | 98.00% | 1.75 min | A100 |
| 12 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 8) | 99.33% | 98.66% | 98.33% | 1.70 min | A100 |
| 13 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120); but no causal mask | 99.23% | 98.66% | 95.33% | 0.29 min | A100 |
 
@ -43,6 +44,7 @@ You can use the following code to reproduce the experiments:
- Row 10: `python additional-experiments.py --context_length "model_context_length"`
- Row 11: `python additional-experiments.py --no_padding --batch_size 1`
- Row 12: `python additional-experiments.py --no_padding --batch_size 1 --accumulation_steps 8`
- Row 13: `python additional-experiments.py --disable_causal_mask`
I've kept the LLM and dataset small on purpose, so you can run the training on a regular laptop like a MacBook Air M3 in about 15 minutes in case you don't have access to a GPU.
@ -65,3 +67,5 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a
7. **Padding Input to Full Context Length vs. Longest Training Example (Row 1 vs. 10)**: Padding the input to the full supported context length results is significantly worse.
8. **Padding vs no padding (Row 1 vs. 11 and 12)**: The `--no_padding` option disables the padding in the dataset, which requires training the model with a batch size of 1 since the inputs have variable lengths. This results in a better test accuracy but takes longer to train. In row 12, we additionally enable gradient accumulation with 8 steps to achieve the same batch size as in the other experiments, which helps reduce overfitting and slightly boost the test set accuracy.
9. **Disabling the causal attention mask (Row 1 vs. 13)**: Disables the causal attention mask used in the multi-head attention module. This means all tokens can attend all other tokens. The model accuracy is slightly improved compared to the GPT model with causal mask.

View File

@ -153,7 +153,7 @@ def instantiate_model(choose_model, load_weights):
if not load_weights:
torch.manual_seed(123)
model = GPTModel(BASE_CONFIG)
model = GPTModel(BASE_CONFIG, disable_causal_mask=args.disable_causal_mask)
if load_weights:
model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
@ -386,6 +386,15 @@ if __name__ == "__main__":
)
)
parser.add_argument(
"--disable_causal_mask",
action='store_true',
default=False,
help=(
"Disables the causal attention mask."
)
)
args = parser.parse_args()
if args.trainable_token == "first":

View File

@ -60,7 +60,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
# Chapter 3
#####################################
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False, disable_causal_mask=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
@ -73,7 +73,10 @@ class MultiHeadAttention(nn.Module):
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
if not disable_causal_mask:
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
self.disable_causal_mask = disable_causal_mask
def forward(self, x):
b, num_tokens, d_in = x.shape
@ -96,11 +99,12 @@ class MultiHeadAttention(nn.Module):
# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
if not self.disable_causal_mask:
# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)
# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
@ -157,7 +161,7 @@ class FeedForward(nn.Module):
class TransformerBlock(nn.Module):
def __init__(self, cfg):
def __init__(self, cfg, disable_causal_mask=False):
super().__init__()
self.att = MultiHeadAttention(
d_in=cfg["emb_dim"],
@ -165,7 +169,9 @@ class TransformerBlock(nn.Module):
context_length=cfg["context_length"],
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"])
qkv_bias=cfg["qkv_bias"],
disable_causal_mask=disable_causal_mask
)
self.ff = FeedForward(cfg)
self.norm1 = LayerNorm(cfg["emb_dim"])
self.norm2 = LayerNorm(cfg["emb_dim"])
@ -190,14 +196,14 @@ class TransformerBlock(nn.Module):
class GPTModel(nn.Module):
def __init__(self, cfg):
def __init__(self, cfg, disable_causal_mask=False):
super().__init__()
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
self.drop_emb = nn.Dropout(cfg["drop_rate"])
self.trf_blocks = nn.Sequential(
*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
*[TransformerBlock(cfg, disable_causal_mask) for _ in range(cfg["n_layers"])])
self.final_norm = LayerNorm(cfg["emb_dim"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)