mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-12-08 13:29:01 +00:00
Improve multiturn stopping condition (#814)
* Improve multiturn stopping condition * improve
This commit is contained in:
parent
4b0021416a
commit
215abdbcdd
@ -115,6 +115,10 @@ REPO_ID, LOCAL_DIR = build_repo_and_local(MODEL, REASONING, LOCAL_DIR)
|
|||||||
DEVICE = get_device(DEVICE)
|
DEVICE = get_device(DEVICE)
|
||||||
MODEL, TOKENIZER = get_model_and_tokenizer(QWEN3_CONFIG, REPO_ID, LOCAL_DIR, DEVICE, REASONING)
|
MODEL, TOKENIZER = get_model_and_tokenizer(QWEN3_CONFIG, REPO_ID, LOCAL_DIR, DEVICE, REASONING)
|
||||||
|
|
||||||
|
# Even though the official TOKENIZER.eos_token_id is either <|im_end|> (reasoning)
|
||||||
|
# or <|endoftext|> (base), the reasoning model sometimes emits both.
|
||||||
|
EOS_TOKEN_IDS = (TOKENIZER.encode("<|im_end|>")[0], TOKENIZER.encode("<|endoftext|>")[0])
|
||||||
|
|
||||||
|
|
||||||
@chainlit.on_chat_start
|
@chainlit.on_chat_start
|
||||||
async def on_start():
|
async def on_start():
|
||||||
@ -147,9 +151,11 @@ async def main(message: chainlit.Message):
|
|||||||
model=MODEL,
|
model=MODEL,
|
||||||
token_ids=input_ids_tensor,
|
token_ids=input_ids_tensor,
|
||||||
max_new_tokens=MAX_NEW_TOKENS,
|
max_new_tokens=MAX_NEW_TOKENS,
|
||||||
eos_token_id=TOKENIZER.eos_token_id
|
# eos_token_id=TOKENIZER.eos_token_id
|
||||||
):
|
):
|
||||||
token_id = tok.squeeze(0)
|
token_id = tok.squeeze(0)
|
||||||
|
if token_id in EOS_TOKEN_IDS:
|
||||||
|
break
|
||||||
piece = TOKENIZER.decode(token_id.tolist())
|
piece = TOKENIZER.decode(token_id.tolist())
|
||||||
await out_msg.stream_token(piece)
|
await out_msg.stream_token(piece)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user