mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-11-12 16:15:22 +00:00
removed unnecessary imports (#367)
This commit is contained in:
parent
76e9a9ec02
commit
5f36d2af4c
@ -52,7 +52,6 @@ def get_model_and_tokenizer():
|
|||||||
model.out_head = torch.nn.Linear(in_features=GPT_CONFIG_124M["emb_dim"], out_features=num_classes)
|
model.out_head = torch.nn.Linear(in_features=GPT_CONFIG_124M["emb_dim"], out_features=num_classes)
|
||||||
|
|
||||||
# Then load model weights
|
# Then load model weights
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
checkpoint = torch.load(model_path, map_location=device, weights_only=True)
|
checkpoint = torch.load(model_path, map_location=device, weights_only=True)
|
||||||
model.load_state_dict(checkpoint)
|
model.load_state_dict(checkpoint)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
@ -71,7 +70,6 @@ async def main(message: chainlit.Message):
|
|||||||
The main Chainlit function.
|
The main Chainlit function.
|
||||||
"""
|
"""
|
||||||
user_input = message.content
|
user_input = message.content
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
|
|
||||||
label = classify_review(user_input, model, tokenizer, device, max_length=120)
|
label = classify_review(user_input, model, tokenizer, device, max_length=120)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user