mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-14 19:51:35 +00:00
30 lines
1.1 KiB
Python
30 lines
1.1 KiB
Python
![]() |
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||
|
# Source for "Build a Large Language Model From Scratch"
|
||
|
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||
|
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cache=True):
    """Greedily append ``max_new_tokens`` tokens to ``idx`` using ``model``.

    At every step the next token is the argmax over the last dimension of the
    logits at the final sequence position; no sampling is involved.

    Args:
        model: Object called as ``model(tokens, use_cache=<bool>)`` whose
            result is indexed as ``logits[:, -1]``. It must provide
            ``eval()``; ``reset_kv_cache()`` is called when ``use_cache`` is
            true, and ``cfg["context_length"]`` is read when ``context_size``
            is falsy.
        idx: Tensor of token ids indexed as ``[batch, position]``.
        max_new_tokens: Number of tokens to generate and append.
        context_size: Maximum number of trailing tokens of ``idx`` handed to
            the model as context. ``None`` (or ``0``) falls back to
            ``model.cfg["context_length"]``.
        use_cache: If true, the prompt is passed once and afterwards only the
            newly generated token is passed per step (``use_cache=True`` is
            forwarded to the model). If false, the trailing ``context_size``
            tokens are passed again on every step.

    Returns:
        ``idx`` with the generated tokens concatenated along dimension 1.

    Note:
        ``model.eval()`` is called and the previous mode is not restored.
        In the cached branch the last generated token is not passed back to
        the model, because its logits would never be used.
    """
    model.eval()

    # ``or`` means an explicit context_size of 0 also uses the model default.
    ctx_len = context_size or model.cfg["context_length"]

    with torch.no_grad():
        if use_cache:
            # Start from an empty cache, then run the (truncated) prompt once.
            model.reset_kv_cache()
            logits = model(idx[:, -ctx_len:], use_cache=True)

            # NOTE(review): unlike the uncached branch, nothing here re-windows
            # the context to ctx_len once generation runs past it; whether the
            # model copes with that is not visible from this function.
            for step in range(max_new_tokens):
                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                idx = torch.cat([idx, next_idx], dim=1)
                # Only run the model again if another token will be picked;
                # a forward pass after the final token would be discarded.
                if step + 1 < max_new_tokens:
                    logits = model(next_idx, use_cache=True)
        else:
            for _ in range(max_new_tokens):
                # Recompute from the trailing window on every step.
                logits = model(idx[:, -ctx_len:], use_cache=False)
                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                idx = torch.cat([idx, next_idx], dim=1)

    return idx
|