Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-11-10 14:57:26 +00:00)
add text-to-token-id fn
parent d3201f5aad
commit e9bdbf0725
@@ -310,36 +310,12 @@ def load_weights_into_gpt(gpt, params):
     gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
 
 
-def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
-    # For-loop is the same as before: Get logits, and only focus on last time step
-    for _ in range(max_new_tokens):
-        idx_cond = idx[:, -context_size:]
-        with torch.no_grad():
-            logits = model(idx_cond)
-        logits = logits[:, -1, :]
-
-        # New: Filter logits with top_k sampling
-        if top_k is not None:
-            # Keep only top_k values
-            top_logits, _ = torch.topk(logits, top_k)
-            min_val = top_logits[:, -1]
-            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)
-
-        # New: Apply temperature scaling
-        if temperature > 0.0:
-            logits = logits / temperature
-
-            # Apply softmax to get probabilities
-            probs = torch.softmax(logits, dim=-1)  # (batch_size, context_len)
-
-            # Sample from the distribution
-            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)
-
-        # Otherwise same as before: get idx of the vocab entry with the highest logits value
-        else:
-            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)
-
-        # Same as before: append sampled index to the running sequence
-        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)
-
-    return idx
+def text_to_token_ids(text, tokenizer):
+    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
+    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
+    return encoded_tensor
+
+
+def token_ids_to_text(token_ids, tokenizer):
+    flat = token_ids.squeeze(0)  # remove batch dimension
+    return tokenizer.decode(flat.tolist())
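
Below, a quick round-trip check of the two helpers this commit adds. This is a sketch, not part of the commit itself: the tiktoken GPT-2 encoding and the sample prompt are assumptions based on the surrounding chapter code.

import torch
import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")  # assumed tokenizer; any object with encode/decode works

token_ids = text_to_token_ids("Every effort moves you", tokenizer)
print(token_ids.shape)  # torch.Size([1, 4]) -- unsqueeze(0) added the batch dimension

print(token_ids_to_text(token_ids, tokenizer))
# Every effort moves you

The two helpers are deliberately symmetric: text_to_token_ids returns a (1, num_tokens) tensor that can be fed straight to the model, and token_ids_to_text strips the batch dimension again before decoding.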
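
For readers of the removed generate function above, here is a small standalone illustration of its top-k masking step, with made-up logits and batch size 1 (a sketch; the numbers are arbitrary):

import torch

logits = torch.tensor([[4.0, 1.0, 3.5, 0.5, 2.0]])
top_logits, _ = torch.topk(logits, 3)  # tensor([[4.0, 3.5, 2.0]]), sorted descending
min_val = top_logits[:, -1]            # tensor([2.0]), the smallest kept logit
masked = torch.where(logits < min_val, torch.tensor(float('-inf')), logits)
# masked: tensor([[4.0, -inf, 3.5, -inf, 2.0]])
probs = torch.softmax(masked / 0.7, dim=-1)  # temperature 0.7; -inf entries become probability 0

Every logit below the third-largest is set to -inf, so after softmax only the top three tokens can be drawn by torch.multinomial.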