diff --git a/ch05/11_qwen3/qwen3-chat-interface/qwen3-chat-interface-multiturn.py b/ch05/11_qwen3/qwen3-chat-interface/qwen3-chat-interface-multiturn.py index be5f3ba..37a4366 100644 --- a/ch05/11_qwen3/qwen3-chat-interface/qwen3-chat-interface-multiturn.py +++ b/ch05/11_qwen3/qwen3-chat-interface/qwen3-chat-interface-multiturn.py @@ -115,6 +115,10 @@ REPO_ID, LOCAL_DIR = build_repo_and_local(MODEL, REASONING, LOCAL_DIR) DEVICE = get_device(DEVICE) MODEL, TOKENIZER = get_model_and_tokenizer(QWEN3_CONFIG, REPO_ID, LOCAL_DIR, DEVICE, REASONING) +# Even though the official TOKENIZER.eos_token_id is either <|im_end|> (reasoning) +# or <|endoftext|> (base), the reasoning model sometimes emits both. +EOS_TOKEN_IDS = (TOKENIZER.encode("<|im_end|>")[0], TOKENIZER.encode("<|endoftext|>")[0]) + @chainlit.on_chat_start async def on_start(): @@ -147,9 +151,11 @@ async def main(message: chainlit.Message): model=MODEL, token_ids=input_ids_tensor, max_new_tokens=MAX_NEW_TOKENS, - eos_token_id=TOKENIZER.eos_token_id + # eos_token_id=TOKENIZER.eos_token_id ): token_id = tok.squeeze(0) + if token_id in EOS_TOKEN_IDS: + break piece = TOKENIZER.decode(token_id.tolist()) await out_msg.stream_token(piece)