Make quote style consistent (#891)

commit 7ca7c47e4a (parent 9276edbc37)
Author: Sebastian Raschka
Date: 2025-10-21 19:42:33 -05:00 (committed by GitHub)
24 changed files with 239 additions and 81 deletions

.github/scripts/check_double_quotes.py (new file, 158 additions)

@@ -0,0 +1,158 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
# Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
# Code repository: https://github.com/rasbt/reasoning-from-scratch
#
# Verify that Python source files (and optionally notebooks) use double quotes
# for strings.

import argparse
import ast
import io
import json
import sys
import tokenize
from pathlib import Path

EXCLUDED_DIRS = {
    ".git",
    ".hg",
    ".mypy_cache",
    ".pytest_cache",
    ".ruff_cache",
    ".svn",
    ".tox",
    ".venv",
    "__pycache__",
    "build",
    "dist",
    "node_modules",
}

PREFIX_CHARS = {"r", "u", "f", "b"}
SINGLE_QUOTE = "'"
DOUBLE_QUOTE = "\""
TRIPLE_SINGLE = SINGLE_QUOTE * 3
TRIPLE_DOUBLE = DOUBLE_QUOTE * 3


def should_skip(path):
    parts = set(path.parts)
    return bool(EXCLUDED_DIRS & parts)


def collect_fstring_expr_string_positions(source):
    """
    Return a set of (lineno, col_offset) pairs for string literals that appear
    inside formatted expressions of f-strings. These are exempt from the double
    quote check, since enforcing double quotes there is unnecessarily strict.
    """
    try:
        tree = ast.parse(source)
    except SyntaxError:
        return set()

    positions = set()

    class Collector(ast.NodeVisitor):
        def visit_JoinedStr(self, node):
            for value in node.values:
                if isinstance(value, ast.FormattedValue):
                    self._collect_from_expr(value.value)
            # Continue walking to catch nested f-strings within expressions
            self.generic_visit(node)

        def _collect_from_expr(self, node):
            if isinstance(node, ast.Constant) and isinstance(node.value, str):
                positions.add((node.lineno, node.col_offset))
            elif hasattr(ast, "Str") and isinstance(node, ast.Str):
                # Python <3.8 compatibility (ast.Str was removed in 3.12,
                # hence the hasattr guard)
                positions.add((node.lineno, node.col_offset))
            else:
                for child in ast.iter_child_nodes(node):
                    self._collect_from_expr(child)

    Collector().visit(tree)
    return positions


def check_quotes_in_source(source, path):
    violations = []
    ignored_positions = collect_fstring_expr_string_positions(source)
    tokens = tokenize.generate_tokens(io.StringIO(source).readline)
    for tok_type, tok_str, start, _, _ in tokens:
        if tok_type == tokenize.STRING:
            if start in ignored_positions:
                continue
            lowered = tok_str.lower()
            # Ignore triple-quoted strings
            if lowered.startswith((TRIPLE_DOUBLE, TRIPLE_SINGLE)):
                continue
            # Strip any string prefix characters (handles multi-character
            # prefixes such as rb or fr as well)
            while lowered and lowered[0] in PREFIX_CHARS:
                lowered = lowered[1:]
            # Report if not using double quotes
            if lowered.startswith(SINGLE_QUOTE):
                line, col = start
                violations.append(f"{path}:{line}:{col}: uses single quotes")
    return violations


def check_file(path):
    try:
        if path.suffix == ".ipynb":
            return check_notebook(path)
        else:
            text = path.read_text(encoding="utf-8")
            return check_quotes_in_source(text, path)
    except Exception as e:
        return [f"{path}: failed to check ({e})"]


def check_notebook(path):
    violations = []
    with open(path, encoding="utf-8") as f:
        nb = json.load(f)
    for cell in nb.get("cells", []):
        if cell.get("cell_type") == "code":
            src = "".join(cell.get("source", []))
            violations.extend(check_quotes_in_source(src, path))
    return violations


def parse_args():
    parser = argparse.ArgumentParser(description="Verify double-quoted string literals.")
    parser.add_argument(
        "--include-notebooks",
        action="store_true",
        help="Also scan Jupyter notebooks (.ipynb files) for single-quoted strings.",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    project_root = Path(".").resolve()
    py_files = sorted(project_root.rglob("*.py"))
    notebook_files = sorted(project_root.rglob("*.ipynb")) if args.include_notebooks else []

    violations = []
    for path in py_files + notebook_files:
        if should_skip(path):
            continue
        violations.extend(check_file(path))

    if violations:
        print("\n".join(violations))
        print(f"\n{len(violations)} violations found.")
        return 1
    print("All files use double quotes correctly.")
    return 0


if __name__ == "__main__":
    sys.exit(main())
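
For reference: running the checker from the repository root with "python .github/scripts/check_double_quotes.py" (optionally adding --include-notebooks) prints each offending location and exits with status 1 when violations are found, so it can serve as a CI gate; the workflow wiring itself is not shown in this diff. Below is a minimal sketch of what the checker flags versus exempts, assuming the file above can be imported as check_double_quotes (module name inferred from the file name; importability from the working directory is an assumption):

from check_double_quotes import check_quotes_in_source

sample = """
greeting = "fine"                # double quotes: passes
name = 'flagged'                 # single quotes: reported
doc = '''triple-quoted strings are ignored'''
msg = f"{config['key']}"         # single quotes inside an f-string expression: exempt
"""

for violation in check_quotes_in_source(sample, "sample.py"):
    print(violation)
# Expected output (one violation):
# sample.py:3:7: uses single quotes

On Python 3.12 and newer, tokenize reports strings inside f-string replacement fields as separate tokens, which is exactly why the ast-based exemption above is needed; on older versions the whole f-string is a single double-quoted token and passes anyway.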

View File

@@ -73,7 +73,7 @@ class MultiHeadAttention(nn.Module):
         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
         self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
         self.dropout = nn.Dropout(dropout)
-        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

     def forward(self, x):
         b, num_tokens, d_in = x.shape
View File

@@ -80,7 +80,7 @@ class MultiHeadAttention(nn.Module):
         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
         self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
         self.dropout = nn.Dropout(dropout)
-        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

     def forward(self, x):
         b, num_tokens, d_in = x.shape
@@ -257,8 +257,8 @@ def assign(left, right):

 def load_weights_into_gpt(gpt, params):
-    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
-    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])
+    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params["wpe"])
+    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params["wte"])

     for b in range(len(params["blocks"])):
         q_w, k_w, v_w = np.split(
@@ -318,7 +318,7 @@ def load_weights_into_gpt(gpt, params):

 def text_to_token_ids(text, tokenizer):
-    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
+    encoded = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
     encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension
     return encoded_tensor

View File

@@ -70,7 +70,7 @@ def get_pairs(word):

 class Encoder:
-    def __init__(self, encoder, bpe_merges, errors='replace'):
+    def __init__(self, encoder, bpe_merges, errors="replace"):
         self.encoder = encoder
         self.decoder = {v: k for k, v in self.encoder.items()}
         self.errors = errors # how to handle errors in decoding
@@ -92,7 +92,7 @@ class Encoder:
            return token

        while True:
-            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
@@ -119,43 +119,43 @@ class Encoder:
                break
            else:
                pairs = get_pairs(word)
-        word = ' '.join(word)
+        word = " ".join(word)
         self.cache[token] = word
         return word

     def encode(self, text):
         bpe_tokens = []
         for token in re.findall(self.pat, text):
-            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
-            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
         return bpe_tokens

     def decode(self, tokens):
-        text = ''.join([self.decoder[token] for token in tokens])
-        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
+        text = "".join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
         return text

 def get_encoder(model_name, models_dir):
-    with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
+    with open(os.path.join(models_dir, model_name, "encoder.json"), "r") as f:
         encoder = json.load(f)
-    with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
+    with open(os.path.join(models_dir, model_name, "vocab.bpe"), "r", encoding="utf-8") as f:
         bpe_data = f.read()
-    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
+    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
     return Encoder(encoder=encoder, bpe_merges=bpe_merges)

 def download_vocab():
     # Modified code from
-    subdir = 'gpt2_model'
+    subdir = "gpt2_model"
     if not os.path.exists(subdir):
         os.makedirs(subdir)
-    subdir = subdir.replace('\\', '/') # needed for Windows
+    subdir = subdir.replace("\\", "/") # needed for Windows

-    for filename in ['encoder.json', 'vocab.bpe']:
+    for filename in ["encoder.json", "vocab.bpe"]:
         r = requests.get("https://openaipublic.blob.core.windows.net/gpt-2/models/117M/" + filename, stream=True)

-        with open(os.path.join(subdir, filename), 'wb') as f:
+        with open(os.path.join(subdir, filename), "wb") as f:
             file_size = int(r.headers["content-length"])
             chunk_size = 1000
             with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:

View File

@@ -60,7 +60,7 @@ class MultiHeadAttention(nn.Module):
         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
         self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
         self.dropout = nn.Dropout(dropout)
-        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

     def forward(self, x):
         b, num_tokens, d_in = x.shape

View File

@@ -33,8 +33,8 @@ def test_main(capsys):
     captured = capsys.readouterr()

     # Normalize line endings and strip trailing whitespace from each line
-    normalized_expected = '\n'.join(line.rstrip() for line in expected.splitlines())
-    normalized_output = '\n'.join(line.rstrip() for line in captured.out.splitlines())
+    normalized_expected = "\n".join(line.rstrip() for line in expected.splitlines())
+    normalized_output = "\n".join(line.rstrip() for line in captured.out.splitlines())

     # Compare normalized strings
     assert normalized_output == normalized_expected

View File

@@ -71,7 +71,7 @@ class MultiHeadAttention(nn.Module):
         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
         self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
         self.dropout = nn.Dropout(dropout)
-        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

     def forward(self, x):
         b, num_tokens, d_in = x.shape

View File

@@ -43,7 +43,7 @@ def combine_files(file_paths, target_dir, max_size_mb=500, separator="<|endoftex
         content = strip_headers(content)

         # Regular expression to replace multiple blank lines with a single blank line
-        content = re.sub(r'\n\s*\n', '\n\n', content)
+        content = re.sub(r"\n\s*\n", "\n\n", content)
         estimated_size = len(content.encode("utf-8"))

         if current_size + estimated_size > max_size_mb * 1024 * 1024:

View File

@@ -148,26 +148,26 @@ def train_model_simple(model, optimizer, device, n_epochs,
 if __name__ == "__main__":

-    parser = argparse.ArgumentParser(description='GPT Model Training Configuration')
+    parser = argparse.ArgumentParser(description="GPT Model Training Configuration")

-    parser.add_argument('--data_dir', type=str, default='gutenberg/data',
-                        help='Directory containing the training data')
-    parser.add_argument('--output_dir', type=str, default='model_checkpoints',
-                        help='Directory where the model checkpoints will be saved')
-    parser.add_argument('--n_epochs', type=int, default=1,
-                        help='Number of epochs to train the model')
-    parser.add_argument('--print_sample_iter', type=int, default=1000,
-                        help='Iterations between printing sample outputs')
-    parser.add_argument('--eval_freq', type=int, default=100,
-                        help='Frequency of evaluations during training')
-    parser.add_argument('--save_ckpt_freq', type=int, default=100_000,
-                        help='Frequency of saving model checkpoints during training')
-    parser.add_argument('--lr', type=float, default=5e-4,
-                        help='Learning rate for the optimizer')
-    parser.add_argument('--batch_size', type=int, default=4,
-                        help='Batch size for training')
-    parser.add_argument('--debug', type=bool, default=False,
-                        help='Uses a very small model for debugging purposes')
+    parser.add_argument("--data_dir", type=str, default="gutenberg/data",
+                        help="Directory containing the training data")
+    parser.add_argument("--output_dir", type=str, default="model_checkpoints",
+                        help="Directory where the model checkpoints will be saved")
+    parser.add_argument("--n_epochs", type=int, default=1,
+                        help="Number of epochs to train the model")
+    parser.add_argument("--print_sample_iter", type=int, default=1000,
+                        help="Iterations between printing sample outputs")
+    parser.add_argument("--eval_freq", type=int, default=100,
+                        help="Frequency of evaluations during training")
+    parser.add_argument("--save_ckpt_freq", type=int, default=100_000,
+                        help="Frequency of saving model checkpoints during training")
+    parser.add_argument("--lr", type=float, default=5e-4,
+                        help="Learning rate for the optimizer")
+    parser.add_argument("--batch_size", type=int, default=4,
+                        help="Batch size for training")
+    parser.add_argument("--debug", type=bool, default=False,
+                        help="Uses a very small model for debugging purposes")

     args = parser.parse_args()
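
A side note on one argument carried over unchanged by this commit: argparse's type=bool does not parse booleans; it calls bool() on the raw string, and any non-empty string is truthy, so "--debug False" still yields True. A hypothetical illustration (not part of the repository):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--debug", type=bool, default=False)

print(parser.parse_args(["--debug", "False"]).debug)  # True, since bool("False") is truthy
print(parser.parse_args([]).debug)                    # False, only via the default

# A common alternative is a flag: parser.add_argument("--debug", action="store_true")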

View File

@@ -118,7 +118,7 @@ if __name__ == "__main__":
     print(f"Total hyperparameter configurations: {total_combinations}")

     # Placeholder for the best loss and best hyperparameters
-    best_val_loss = float('inf')
+    best_val_loss = float("inf")
     best_hparams = {}

     script_path = os.path.abspath(__file__)

View File

@@ -38,7 +38,7 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
             # Keep only top_k values
             top_logits, _ = torch.topk(logits, top_k)
             min_val = top_logits[:, -1]
-            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)
+            logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)

         # New: Apply temperature scaling
         if temperature > 0.0:

View File

@@ -29,7 +29,7 @@ class MultiHeadAttention(nn.Module):
         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
         self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
         self.dropout = nn.Dropout(dropout)
-        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

     def forward(self, x):
         b, num_tokens, d_in = x.shape

View File

@@ -426,7 +426,7 @@ def main(gpt_config, settings):
     if not os.path.exists(file_path):
         with urllib.request.urlopen(url) as response:
-            text_data = response.read().decode('utf-8')
+            text_data = response.read().decode("utf-8")
         with open(file_path, "w", encoding="utf-8") as file:
             file.write(text_data)
     else:

View File

@@ -72,7 +72,7 @@ class MultiHeadAttention(nn.Module):
         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
         self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
         self.dropout = nn.Dropout(dropout)
-        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

     def forward(self, x):
         b, num_tokens, d_in = x.shape
@@ -249,8 +249,8 @@ def assign(left, right):

 def load_weights_into_gpt(gpt, params):
-    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
-    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])
+    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params["wpe"])
+    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params["wte"])

     for b in range(len(params["blocks"])):
         q_w, k_w, v_w = np.split(
@@ -310,7 +310,7 @@ def load_weights_into_gpt(gpt, params):

 def text_to_token_ids(text, tokenizer):
-    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
+    encoded = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
     encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension
     return encoded_tensor

View File

@@ -446,7 +446,7 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--average_embeddings",
-        action='store_true',
+        action="store_true",
         default=False,
         help=(
             "Average the output embeddings from all tokens instead of using"
@@ -480,7 +480,7 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--no_padding",
-        action='store_true',
+        action="store_true",
         default=False,
         help=(
             "Disable padding, which means each example may have a different length."
@@ -517,7 +517,7 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--disable_causal_mask",
-        action='store_true',
+        action="store_true",
         default=False,
         help=(
             "Disables the causal attention mask."

View File

@@ -74,7 +74,7 @@ class MultiHeadAttention(nn.Module):
         self.dropout = nn.Dropout(dropout)

         if not disable_causal_mask:
-            self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+            self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))
         self.disable_causal_mask = disable_causal_mask

     def forward(self, x):
@@ -255,8 +255,8 @@ def assign(left, right):

 def load_weights_into_gpt(gpt, params):
-    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
-    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])
+    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params["wpe"])
+    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params["wte"])

     for b in range(len(params["blocks"])):
         q_w, k_w, v_w = np.split(
@@ -328,7 +328,7 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
             # Keep only top_k values
             top_logits, _ = torch.topk(logits, top_k)
             min_val = top_logits[:, -1]
-            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)
+            logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)

         # New: Apply temperature scaling
         if temperature > 0.0:

View File

@@ -73,7 +73,7 @@ class MultiHeadAttention(nn.Module):
         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
         self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
         self.dropout = nn.Dropout(dropout)
-        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

     def forward(self, x):
         b, num_tokens, d_in = x.shape
@@ -250,8 +250,8 @@ def assign(left, right):

 def load_weights_into_gpt(gpt, params):
-    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
-    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])
+    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params["wpe"])
+    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params["wte"])

     for b in range(len(params["blocks"])):
         q_w, k_w, v_w = np.split(
@@ -311,7 +311,7 @@ def load_weights_into_gpt(gpt, params):

 def text_to_token_ids(text, tokenizer):
-    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
+    encoded = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
     encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension
     return encoded_tensor

View File

@@ -261,7 +261,7 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--average_embeddings",
-        action='store_true',
+        action="store_true",
         default=False,
         help=(
             "Average the output embeddings from all tokens instead of using"

View File

@@ -77,7 +77,7 @@ class MultiHeadAttention(nn.Module):
         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
         self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
         self.dropout = nn.Dropout(dropout)
-        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

     def forward(self, x):
         b, num_tokens, d_in = x.shape
@@ -261,7 +261,7 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
             # Keep only top_k values
             top_logits, _ = torch.topk(logits, top_k)
             min_val = top_logits[:, -1]
-            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)
+            logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)

         # New: Apply temperature scaling
         if temperature > 0.0:
@@ -356,8 +356,8 @@ def assign(left, right):

 def load_weights_into_gpt(gpt, params):
-    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
-    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])
+    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params["wpe"])
+    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params["wte"])

     for b in range(len(params["blocks"])):
         q_w, k_w, v_w = np.split(

View File

@@ -34,7 +34,7 @@ def preprocess_text(text):
     # Lowercase the text
     text = text.lower()
     # Remove punctuation
-    text = re.sub(r'[^\w\s]', '', text)
+    text = re.sub(r"[^\w\s]", "", text)

     return text
@@ -50,7 +50,7 @@ def find_near_duplicates(json_data, threshold=0.75, key="instruction"):
         return {}, near_duplicates

     # Vectorize the text data
-    vectorizer = TfidfVectorizer(stop_words=None, analyzer='char', ngram_range=(1, 3))
+    vectorizer = TfidfVectorizer(stop_words=None, analyzer="char", ngram_range=(1, 3))
     tfidf_matrix = vectorizer.fit_transform(text)

     # Compute cosine similarity between each pair of entries
@@ -84,7 +84,7 @@ def find_print_and_remove_near_duplicates(json_data, remove_duplicates=False, th
         json_data, near_duplicates = find_near_duplicates(json_data, key=key, threshold=threshold)
     else:
         _, near_duplicates = find_near_duplicates(json_data, key=key, threshold=threshold)
-    separator = 50 * '='
+    separator = 50 * "="
     print(f"\n\n{separator}\nSearching '{key}' for duplicates ...\n{separator}")
     if not near_duplicates:
         print("No duplicates found")
@@ -114,7 +114,7 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--remove_duplicates",
-        action='store_true',
+        action="store_true",
         default=False,
         help=(
             "Removes duplicates based on the 'input' or 'output' keys "

View File

@@ -77,7 +77,7 @@ class MultiHeadAttention(nn.Module):
         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
         self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
         self.dropout = nn.Dropout(dropout)
-        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

     def forward(self, x):
         b, num_tokens, d_in = x.shape
@@ -261,7 +261,7 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
             # Keep only top_k values
             top_logits, _ = torch.topk(logits, top_k)
             min_val = top_logits[:, -1]
-            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)
+            logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)

         # New: Apply temperature scaling
         if temperature > 0.0:
@@ -357,8 +357,8 @@ def assign(left, right):

 def load_weights_into_gpt(gpt, params):
-    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
-    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])
+    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params["wpe"])
+    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params["wte"])

     for b in range(len(params["blocks"])):
         q_w, k_w, v_w = np.split(

View File

@@ -59,7 +59,7 @@ class CausalAttention(nn.Module):
         self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
         self.dropout = nn.Dropout(dropout) # New
-        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New
+        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New

     def forward(self, x):
         b, num_tokens, d_in = x.shape # New batch dimension b
@@ -109,7 +109,7 @@ class MultiHeadAttention(nn.Module):
         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
         self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
         self.dropout = nn.Dropout(dropout)
-        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

     def forward(self, x):
         b, num_tokens, d_in = x.shape

View File

@@ -30,7 +30,7 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
             # Keep only top_k values
             top_logits, _ = torch.topk(logits, top_k)
             min_val = top_logits[:, -1]
-            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)
+            logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)

         # New: Apply temperature scaling
         if temperature > 0.0:
@@ -125,8 +125,8 @@ def assign(left, right):

 def load_weights_into_gpt(gpt, params):
-    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
-    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])
+    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params["wpe"])
+    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params["wte"])

     for b in range(len(params["blocks"])):
         q_w, k_w, v_w = np.split(

View File

@@ -110,7 +110,7 @@ def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
     out = model(dummy_input)
     assert out.shape == (1, dummy_input.size(1), dummy_cfg_moe["vocab_size"]), \
         f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
-    assert any(hasattr(block.ff, 'gate') for block in model.trf_blocks), \
+    assert any(hasattr(block.ff, "gate") for block in model.trf_blocks), \
         "Expected MoEFeedForward in at least one transformer block"