More efficient angles computation in RoPE (#830)

This commit is contained in:
Sebastian Raschka 2025-09-15 22:23:33 -05:00 committed by GitHub
parent 147dc49ab5
commit b6cd0a312f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 11 additions and 11 deletions

View File

@@ -435,7 +435,7 @@
" positions = torch.arange(context_length)\n",
"\n",
" # Compute the angles\n",
-" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
+" angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n",
"\n",
" # Expand angles to match the head_dim\n",
" angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",

View File

@@ -310,7 +310,7 @@
" positions = torch.arange(context_length)\n",
"\n",
" # Compute the angles\n",
-" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
+" angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n",
"\n",
" # Expand angles to match the head_dim\n",
" angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",

View File

@@ -180,7 +180,7 @@
" positions = torch.arange(context_length, dtype=dtype)\n",
"\n",
" # Compute the angles\n",
-" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
+" angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n",
"\n",
" # Expand angles to match the head_dim\n",
" angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",

View File

@@ -275,7 +275,7 @@
" positions = torch.arange(context_length, dtype=dtype)\n",
"\n",
" # Compute the angles\n",
-" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
+" angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n",
"\n",
" # Expand angles to match the head_dim\n",
" angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",

View File

@@ -275,7 +275,7 @@
" positions = torch.arange(context_length, dtype=dtype)\n",
"\n",
" # Compute the angles\n",
-" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
+" angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n",
"\n",
" # Expand angles to match the head_dim\n",
" angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",

View File

@@ -206,7 +206,7 @@
" positions = torch.arange(context_length, dtype=dtype)\n",
"\n",
" # Compute the angles\n",
-" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
+" angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n",
"\n",
" # Expand angles to match the head_dim\n",
" angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",

View File

@@ -204,7 +204,7 @@
" positions = torch.arange(context_length, dtype=dtype)\n",
"\n",
" # Compute the angles\n",
-" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
+" angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n",
"\n",
" # Expand angles to match the head_dim\n",
" angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",

View File

@@ -200,7 +200,7 @@
" positions = torch.arange(context_length, dtype=dtype)\n",
"\n",
" # Compute the angles\n",
-" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
+" angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n",
"\n",
" # Expand angles to match the head_dim\n",
" angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",

View File

@@ -200,7 +200,7 @@
" positions = torch.arange(context_length, dtype=dtype)\n",
"\n",
" # Compute the angles\n",
-" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
+" angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)\n",
"\n",
" # Expand angles to match the head_dim\n",
" angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n",

View File

@@ -238,7 +238,7 @@ def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_c
positions = torch.arange(context_length, dtype=dtype)
# Compute the angles
-angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)
+angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)
# Expand angles to match the head_dim
angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)

View File

@@ -326,7 +326,7 @@ def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=
positions = torch.arange(context_length, dtype=dtype)
# Compute the angles
-angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)
+angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # Shape: (context_length, head_dim // 2)
# Expand angles to match the head_dim
angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)