removed unnecessary imports (#367)

This commit is contained in:
Daniel Kleine 2024-09-22 18:59:37 +02:00 committed by GitHub
parent 76e9a9ec02
commit 5f36d2af4c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -52,7 +52,6 @@ def get_model_and_tokenizer():
model.out_head = torch.nn.Linear(in_features=GPT_CONFIG_124M["emb_dim"], out_features=num_classes)
# Then load model weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint)
model.to(device)
@ -71,7 +70,6 @@ async def main(message: chainlit.Message):
The main Chainlit function.
"""
user_input = message.content
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
label = classify_review(user_input, model, tokenizer, device, max_length=120)