Improve multiturn stopping condition (#814)

* Improve multiturn stopping condition

* improve
Sebastian Raschka 2025-09-09 19:37:15 -05:00 committed by GitHub
parent 4b0021416a
commit 215abdbcdd

@@ -115,6 +115,10 @@ REPO_ID, LOCAL_DIR = build_repo_and_local(MODEL, REASONING, LOCAL_DIR)
 DEVICE = get_device(DEVICE)
 MODEL, TOKENIZER = get_model_and_tokenizer(QWEN3_CONFIG, REPO_ID, LOCAL_DIR, DEVICE, REASONING)
 
+# Even though the official TOKENIZER.eos_token_id is either <|im_end|> (reasoning)
+# or <|endoftext|> (base), the reasoning model sometimes emits both.
+EOS_TOKEN_IDS = (TOKENIZER.encode("<|im_end|>")[0], TOKENIZER.encode("<|endoftext|>")[0])
+
 
 @chainlit.on_chat_start
 async def on_start():
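
For context on the new EOS_TOKEN_IDS tuple: each of the two special tokens encodes to a single id, which is why taking element [0] of TOKENIZER.encode(...) is enough. Below is a minimal, self-contained sketch of the same idea; the stub tokenizer, the helper name collect_stop_ids, and the id values are illustrative stand-ins, not part of the repo:

# Minimal sketch: build a set of stop ids from special-token strings.
# StubTokenizer stands in for the real Qwen3 tokenizer; only its `encode`
# method (string -> list of ids) mirrors the interface used in the diff.

class StubTokenizer:
    _special = {"<|im_end|>": 151645, "<|endoftext|>": 151643}  # illustrative ids

    def encode(self, text):
        # Each special token maps to exactly one id.
        return [self._special[text]]


def collect_stop_ids(tokenizer, special_tokens):
    """Encode each special token and record its single id as a stop id."""
    stop_ids = set()
    for token in special_tokens:
        ids = tokenizer.encode(token)
        assert len(ids) == 1, f"{token!r} should encode to exactly one id"
        stop_ids.add(ids[0])
    return stop_ids


if __name__ == "__main__":
    eos_token_ids = collect_stop_ids(StubTokenizer(), ["<|im_end|>", "<|endoftext|>"])
    print(eos_token_ids)  # {151643, 151645}

Either a tuple (as in the commit) or a set works for the membership test; a set would only start to matter if the stop list grew beyond a couple of ids.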
@@ -147,9 +151,11 @@ async def main(message: chainlit.Message):
         model=MODEL,
         token_ids=input_ids_tensor,
         max_new_tokens=MAX_NEW_TOKENS,
-        eos_token_id=TOKENIZER.eos_token_id
+        # eos_token_id=TOKENIZER.eos_token_id
     ):
         token_id = tok.squeeze(0)
+        if token_id in EOS_TOKEN_IDS:
+            break
         piece = TOKENIZER.decode(token_id.tolist())
         await out_msg.stream_token(piece)
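
The stopping logic itself does not depend on torch or chainlit. The sketch below reproduces the loop with a toy generator and plain ints standing in for the model stream and the tensor tokens, and string concatenation standing in for out_msg.stream_token; every name and value in it is illustrative:

# Self-contained sketch of the new stopping condition: stream token ids,
# stop as soon as either EOS id appears, and decode everything before it.

EOS_TOKEN_IDS = (1, 2)  # stand-ins for the <|im_end|> and <|endoftext|> ids

VOCAB = {10: "Hello", 11: ",", 12: " world", 13: "!"}


def fake_stream():
    # Yields token ids the way the model's streaming generator does;
    # id 1 is an EOS here, so the trailing 13 must never be emitted.
    yield from (10, 11, 12, 13, 1, 13)


def decode(token_id):
    return VOCAB.get(token_id, "")


def run():
    pieces = []
    for token_id in fake_stream():
        if token_id in EOS_TOKEN_IDS:  # stop on *either* EOS token
            break
        pieces.append(decode(token_id))
    return "".join(pieces)


if __name__ == "__main__":
    print(run())  # -> Hello, world!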