# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import torch

from .utils import KVCache


def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cache=True):
    """Greedily generate up to max_new_tokens tokens and return the extended sequence."""
    model.eval()
    ctx_len = context_size or model.cfg["context_length"]

    with torch.no_grad():
        if use_cache:
            # Fill the KV cache with the prompt once; subsequent steps feed
            # only the newly generated token and reuse the cached history
            cache = KVCache(n_layers=model.cfg["n_layers"])
            model.reset_kv_cache()
            logits = model(idx[:, -ctx_len:], cache=cache)

            for _ in range(max_new_tokens):
                # Greedy decoding: pick the most probable next token
                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                idx = torch.cat([idx, next_idx], dim=1)
                # Only the new token goes through the model; the cache
                # supplies the attention context
                logits = model(next_idx, cache=cache)
        else:
            # Without a cache, re-run the model on the (truncated) full
            # sequence at every step
            for _ in range(max_new_tokens):
                logits = model(idx[:, -ctx_len:], cache=None)
                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                idx = torch.cat([idx, next_idx], dim=1)

    return idx
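
# Example usage (a minimal sketch: GPTModel and GPT_CONFIG_124M are assumed
# stand-ins for the model class and config dict used alongside this module,
# and tiktoken's GPT-2 encoding is assumed as the tokenizer):
#
#   import tiktoken
#   tokenizer = tiktoken.get_encoding("gpt2")
#   model = GPTModel(GPT_CONFIG_124M)
#   idx = torch.tensor([tokenizer.encode("Hello, I am")])
#   out = generate_text_simple(model, idx, max_new_tokens=10)
#   print(tokenizer.decode(out.squeeze(0).tolist()))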


def generate_text_simple_stream(model, token_ids, max_new_tokens, eos_token_id=None, context_size=None):
    """Like generate_text_simple, but yields each token as soon as it is generated."""
    model.eval()
    ctx_len = context_size or model.cfg["context_length"]

    with torch.no_grad():
        cache = KVCache(n_layers=model.cfg["n_layers"])
        model.reset_kv_cache()

        # Prime the cache with the initial context, truncated to the
        # model's context window
        logits = model(token_ids[:, -ctx_len:], cache=cache)

        for _ in range(max_new_tokens):
            next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)

            # Stop as soon as every sequence in the batch emits the
            # end-of-sequence token
            if eos_token_id is not None and torch.all(next_token == eos_token_id):
                break

            yield next_token

            # Bookkeeping: the model call below only needs next_token, since
            # the cache already holds the history
            token_ids = torch.cat([token_ids, next_token], dim=1)

            # Feed only the new token to the model; the cache handles history
            logits = model(next_token, cache=cache)
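
# Example usage of the streaming variant (same assumptions as the sketch
# above; tokens are printed as soon as they are generated):
#
#   token_ids = torch.tensor([tokenizer.encode("Hello, I am")])
#   for tok in generate_text_simple_stream(model, token_ids, max_new_tokens=10):
#       print(tokenizer.decode(tok.squeeze(0).tolist()), end="", flush=True)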