mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-11-02 10:50:30 +00:00
add text-to-token-id fn
This commit is contained in:
parent
c6fcadb087
commit
244593ce01
@ -310,36 +310,12 @@ def load_weights_into_gpt(gpt, params):
|
||||
gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
|
||||
|
||||
|
||||
def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
|
||||
# For-loop is the same as before: Get logits, and only focus on last time step
|
||||
for _ in range(max_new_tokens):
|
||||
idx_cond = idx[:, -context_size:]
|
||||
with torch.no_grad():
|
||||
logits = model(idx_cond)
|
||||
logits = logits[:, -1, :]
|
||||
def text_to_token_ids(text, tokenizer):
|
||||
encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
|
||||
encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension
|
||||
return encoded_tensor
|
||||
|
||||
# New: Filter logits with top_k sampling
|
||||
if top_k is not None:
|
||||
# Keep only top_k values
|
||||
top_logits, _ = torch.topk(logits, top_k)
|
||||
min_val = top_logits[:, -1]
|
||||
logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)
|
||||
|
||||
# New: Apply temperature scaling
|
||||
if temperature > 0.0:
|
||||
logits = logits / temperature
|
||||
|
||||
# Apply softmax to get probabilities
|
||||
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
||||
|
||||
# Sample from the distribution
|
||||
idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1)
|
||||
|
||||
# Otherwise same as before: get idx of the vocab entry with the highest logits value
|
||||
else:
|
||||
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1)
|
||||
|
||||
# Same as before: append sampled index to the running sequence
|
||||
idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1)
|
||||
|
||||
return idx
|
||||
def token_ids_to_text(token_ids, tokenizer):
|
||||
flat = token_ids.squeeze(0) # remove batch dimension
|
||||
return tokenizer.decode(flat.tolist())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user