mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-16 12:41:42 +00:00
Fix truncation issue in classify_review function (#373)
This commit is contained in:
parent
b56d0b2942
commit
7ef5129e18
@ -2207,7 +2207,9 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" # Prepare inputs to the model\n",
|
" # Prepare inputs to the model\n",
|
||||||
" input_ids = tokenizer.encode(text)\n",
|
" input_ids = tokenizer.encode(text)\n",
|
||||||
" supported_context_length = model.pos_emb.weight.shape[1]\n",
|
" supported_context_length = model.pos_emb.weight.shape[0]\n",
|
||||||
|
" # Note: In the book, this was originally written as pos_emb.weight.shape[1] by mistake\n",
|
||||||
|
" # It didn't break the code but would have caused unnecessary truncation (to 768 instead of 1024)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Truncate sequences if they too long\n",
|
" # Truncate sequences if they too long\n",
|
||||||
" input_ids = input_ids[:min(max_length, supported_context_length)]\n",
|
" input_ids = input_ids[:min(max_length, supported_context_length)]\n",
|
||||||
|
@ -179,7 +179,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" # Prepare inputs to the model\n",
|
" # Prepare inputs to the model\n",
|
||||||
" input_ids = tokenizer.encode(text)\n",
|
" input_ids = tokenizer.encode(text)\n",
|
||||||
" supported_context_length = model.pos_emb.weight.shape[1]\n",
|
" supported_context_length = model.pos_emb.weight.shape[0]\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Truncate sequences if they too long\n",
|
" # Truncate sequences if they too long\n",
|
||||||
" input_ids = input_ids[:min(max_length, supported_context_length)]\n",
|
" input_ids = input_ids[:min(max_length, supported_context_length)]\n",
|
||||||
|
@ -353,7 +353,7 @@ def classify_review(text, model, tokenizer, device, max_length=None, pad_token_i
|
|||||||
|
|
||||||
# Prepare inputs to the model
|
# Prepare inputs to the model
|
||||||
input_ids = tokenizer.encode(text)
|
input_ids = tokenizer.encode(text)
|
||||||
supported_context_length = model.pos_emb.weight.shape[1]
|
supported_context_length = model.pos_emb.weight.shape[0]
|
||||||
|
|
||||||
# Truncate sequences if they too long
|
# Truncate sequences if they too long
|
||||||
input_ids = input_ids[:min(max_length, supported_context_length)]
|
input_ids = input_ids[:min(max_length, supported_context_length)]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user