Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-12-27 07:02:08 +00:00)
Fix MHAEinsum weight dimension bug when d_in != d_out (#857)

Previously, MHAEinsum initialized its weight matrices with shape (d_out, d_in) and used einsum notation that did not match that layout, so the layer failed whenever the input and output dimensions differed. This commit corrects the weight initialization to shape (d_in, d_out), updates the einsum notation to 'bnd,do->bno', and adds three unit tests that verify parity with the chapter 3 multi-head attention implementation across different d_in and d_out settings. All tests pass.

* Fix MHAEinsum weight dimension bug when d_in != d_out
* use pytest
* Update .gitignore

Co-authored-by: rasbt <mail@sebastianraschka.com>
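For context, here is a minimal sketch of what the corrected projection amounts to. Only the (d_in, d_out) weight layout and the 'bnd,do->bno' einsum string come from the commit itself; the variable names and example shapes are illustrative, not the notebook's exact MHAEinsum code.

import torch

# Illustrative shapes; d_in != d_out is the case the fix addresses
d_in, d_out, batch, seq_len = 768, 1536, 2, 4

# Corrected initialization: weight is (d_in, d_out) rather than (d_out, d_in)
W_query = torch.nn.Parameter(torch.randn(d_in, d_out))
x = torch.randn(batch, seq_len, d_in)

# Corrected notation: contract over d (= d_in), giving (batch, seq_len, d_out),
# which works even when d_in and d_out differ
queries = torch.einsum("bnd,do->bno", x, W_query)
assert queries.shape == (batch, seq_len, d_out)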
commit 27d52d6378 (parent b1db33b384)
.github/workflows/basic-tests-linux-uv.yml (1 line added)

@@ -48,6 +48,7 @@ jobs:
       run: |
         source .venv/bin/activate
         pytest setup/02_installing-python-libraries/tests.py
+        pytest ch03/02_bonus_efficient-multihead-attention/tests/test_mha_implementations.py
         pytest ch04/01_main-chapter-code/tests.py
         pytest ch04/03_kv-cache/tests.py
         pytest ch05/01_main-chapter-code/tests.py
.gitignore (2 lines added)

@@ -328,3 +328,5 @@ cython_debug/
 # pixi environments
 .pixi
+*.egg-info
mha-implementations.ipynb: file diff suppressed because one or more lines are too long
ch03/02_bonus_efficient-multihead-attention/tests/test_mha_implementations.py (new file)

@@ -0,0 +1,63 @@
from pathlib import Path

import torch
import pytest

from llms_from_scratch.utils import import_definitions_from_notebook


@pytest.fixture
def nb_imports():
    nb_dir = Path(__file__).resolve().parents[1]
    mod = import_definitions_from_notebook(nb_dir, "mha-implementations.ipynb")
    return mod


def copy_weights(from_mha, to_mha):
    with torch.no_grad():
        to_mha.W_query.copy_(from_mha.W_query.weight.T)
        to_mha.W_key.copy_(from_mha.W_key.weight.T)
        to_mha.W_value.copy_(from_mha.W_value.weight.T)

        to_mha.out_proj.weight.copy_(from_mha.out_proj.weight)
        to_mha.out_proj.bias.copy_(from_mha.out_proj.bias)


@pytest.mark.parametrize(
    "d_in,d_out,batch,seq_len,num_heads,seed",
    [
        (768, 768, 2, 4, 12, 123),   # d_in == d_out
        (768, 1536, 2, 4, 12, 456),  # d_in != d_out
        (1024, 512, 2, 4, 8, 789),   # d_in > d_out
    ],
)
def test_mha_einsum_matches_ch03(d_in, d_out, batch, seq_len, num_heads, seed, nb_imports):
    torch.manual_seed(seed)

    x = torch.randn(batch, seq_len, d_in)

    mha_linear = nb_imports.Ch03_MHA(
        d_in=d_in,
        d_out=d_out,
        context_length=seq_len,
        dropout=0.0,
        num_heads=num_heads,
        qkv_bias=False,
    ).eval()

    mha_einsum = nb_imports.MHAEinsum(
        d_in=d_in,
        d_out=d_out,
        context_length=seq_len,
        dropout=0.0,
        num_heads=num_heads,
        qkv_bias=False,
    ).eval()

    copy_weights(mha_linear, mha_einsum)

    out_linear = mha_linear(x)
    out_einsum = mha_einsum(x)

    assert out_linear.shape == out_einsum.shape == torch.Size([batch, seq_len, d_out])
    assert torch.allclose(out_linear, out_einsum, atol=1e-5)
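After this change, the new parity tests can be run locally with the same invocation the CI workflow above now uses, pytest ch03/02_bonus_efficient-multihead-attention/tests/test_mha_implementations.py, assuming pytest, torch, and the repo's llms_from_scratch package are installed in the active environment.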