diff --git a/ch06/01_main-chapter-code/previous_chapters.py b/ch06/01_main-chapter-code/previous_chapters.py index 59a5017..4fc0f7e 100644 --- a/ch06/01_main-chapter-code/previous_chapters.py +++ b/ch06/01_main-chapter-code/previous_chapters.py @@ -312,10 +312,10 @@ def load_weights_into_gpt(gpt, params): def text_to_token_ids(text, tokenizer): encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'}) - encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension + encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension return encoded_tensor def token_ids_to_text(token_ids, tokenizer): - flat = token_ids.squeeze(0) # remove batch dimension + flat = token_ids.squeeze(0) # remove batch dimension return tokenizer.decode(flat.tolist())