diff --git a/ch06/01_main-chapter-code/previous_chapters.py b/ch06/01_main-chapter-code/previous_chapters.py
index e794f9b..59a5017 100644
--- a/ch06/01_main-chapter-code/previous_chapters.py
+++ b/ch06/01_main-chapter-code/previous_chapters.py
@@ -310,36 +310,12 @@ def load_weights_into_gpt(gpt, params):
     gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
 
 
-def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
-    # For-loop is the same as before: Get logits, and only focus on last time step
-    for _ in range(max_new_tokens):
-        idx_cond = idx[:, -context_size:]
-        with torch.no_grad():
-            logits = model(idx_cond)
-        logits = logits[:, -1, :]
+def text_to_token_ids(text, tokenizer):
+    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
+    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
+    return encoded_tensor
 
-        # New: Filter logits with top_k sampling
-        if top_k is not None:
-            # Keep only top_k values
-            top_logits, _ = torch.topk(logits, top_k)
-            min_val = top_logits[:, -1]
-            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)
 
-        # New: Apply temperature scaling
-        if temperature > 0.0:
-            logits = logits / temperature
-
-            # Apply softmax to get probabilities
-            probs = torch.softmax(logits, dim=-1)  # (batch_size, context_len)
-
-            # Sample from the distribution
-            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)
-
-        # Otherwise same as before: get idx of the vocab entry with the highest logits value
-        else:
-            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)
-
-        # Same as before: append sampled index to the running sequence
-        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)
-
-    return idx
+def token_ids_to_text(token_ids, tokenizer):
+    flat = token_ids.squeeze(0)  # remove batch dimension
+    return tokenizer.decode(flat.tolist())
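
Below the patch, a minimal usage sketch of the two helpers it adds. This is an illustration only, not part of the patch; it assumes the script sits next to `previous_chapters.py` and that the tokenizer is tiktoken's GPT-2 encoding, as used elsewhere in the chapter code.

```python
import tiktoken

from previous_chapters import text_to_token_ids, token_ids_to_text

tokenizer = tiktoken.get_encoding("gpt2")

# Encode a prompt; the result has shape (1, num_tokens) because
# text_to_token_ids adds a batch dimension via unsqueeze(0)
token_ids = text_to_token_ids("Every effort moves you", tokenizer)
print(token_ids.shape)

# Decode back to a string; token_ids_to_text removes the batch
# dimension with squeeze(0) before calling tokenizer.decode
print(token_ids_to_text(token_ids, tokenizer))
```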