mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-12-03 02:21:11 +00:00
Improve multiturn stopping condition (#814)
* Improve multiturn stopping condition
* improve
This commit is contained in:
parent
4b0021416a
commit
215abdbcdd
@ -115,6 +115,10 @@ REPO_ID, LOCAL_DIR = build_repo_and_local(MODEL, REASONING, LOCAL_DIR)
|
||||
DEVICE = get_device(DEVICE)
|
||||
MODEL, TOKENIZER = get_model_and_tokenizer(QWEN3_CONFIG, REPO_ID, LOCAL_DIR, DEVICE, REASONING)
|
||||
|
||||
# Even though the official TOKENIZER.eos_token_id is either <|im_end|> (reasoning)
|
||||
# or <|endoftext|> (base), the reasoning model sometimes emits both.
|
||||
EOS_TOKEN_IDS = (TOKENIZER.encode("<|im_end|>")[0], TOKENIZER.encode("<|endoftext|>")[0])
|
||||
|
||||
|
||||
@chainlit.on_chat_start
|
||||
async def on_start():
|
||||
@ -147,9 +151,11 @@ async def main(message: chainlit.Message):
|
||||
model=MODEL,
|
||||
token_ids=input_ids_tensor,
|
||||
max_new_tokens=MAX_NEW_TOKENS,
|
||||
eos_token_id=TOKENIZER.eos_token_id
|
||||
# eos_token_id=TOKENIZER.eos_token_id
|
||||
):
|
||||
token_id = tok.squeeze(0)
|
||||
if token_id in EOS_TOKEN_IDS:
|
||||
break
|
||||
piece = TOKENIZER.decode(token_id.tolist())
|
||||
await out_msg.stream_token(piece)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user