Chainlit bonus material fixes (#361)

* fix cmd

* moved idx to device

* improved code with clone().detach()

* fixed path

* fix: added extra line for pep8

* updated .gitignore

* Update ch05/06_user_interface/app_orig.py

* Update ch05/06_user_interface/app_own.py

* Apply suggestions from code review

---------

Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>
Daniel Kleine 2024-09-18 17:08:50 +02:00 committed by GitHub
parent ea9b4e83a4
commit eefe4bf12b
4 changed files with 11 additions and 10 deletions
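
Taken together, the "moved idx to device" and "improved code with clone().detach()" bullets are standard PyTorch hygiene: input tensors must live on the same device as the model, and an existing tensor should be copied with `clone().detach()` rather than re-wrapped in `torch.tensor()`. Below is a minimal, self-contained sketch of both idioms; the tokenizer setup and the helper name `to_token_ids` are illustrative stand-ins, not the repo's exact code.

```python
import tiktoken
import torch

# Pick the same device the model will live on.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = tiktoken.get_encoding("gpt2")


def to_token_ids(text, tokenizer):
    # Hypothetical stand-in for the book's text_to_token_ids helper:
    # encode the text and add a batch dimension.
    encoded = tokenizer.encode(text)
    return torch.tensor(encoded).unsqueeze(0)


# Fix 1: move the token IDs onto the model's device before generation.
idx = to_token_ids("Every effort moves you", tokenizer).to(device)

# Fix 2: to copy an existing tensor, use clone().detach();
# torch.tensor(existing_tensor) raises a UserWarning in PyTorch.
idx_copy = idx.clone().detach()
```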

.gitignore

@@ -92,6 +92,7 @@ ch07/04_preference-tuning-with-dpo/loss-plot.pdf
 # Other
 ch05/06_user_interface/chainlit.md
 ch05/06_user_interface/.chainlit
+ch05/06_user_interface/.files
 
 # Temporary OS-related files
 .DS_Store

ch05/06_user_interface/README.md

@@ -17,7 +17,7 @@ To implement this user interface, we use the open-source [Chainlit Python package
 First, we install the `chainlit` package via
 
-```python
+```bash
 pip install chainlit
 ```
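
The fence fix reflects that `pip install chainlit` is a shell command, not Python code. For context (not part of this diff), the app is then launched with Chainlit's CLI, e.g. `chainlit run app_own.py` (or `app_orig.py`), run from the `ch05/06_user_interface` directory so the relative checkpoint path resolves.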

ch05/06_user_interface/app_orig.py

@@ -16,6 +16,8 @@ from previous_chapters import (
     token_ids_to_text,
 )
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 def get_model_and_tokenizer():
     """
@@ -44,8 +46,6 @@ def get_model_and_tokenizer():
     BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
     settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
 
     gpt = GPTModel(BASE_CONFIG)
@@ -67,9 +67,9 @@ async def main(message: chainlit.Message):
     """
     The main Chainlit function.
     """
-    token_ids = generate(
+    token_ids = generate(  # function uses `with torch.no_grad()` internally already
         model=model,
-        idx=text_to_token_ids(message.content, tokenizer),  # The user text is provided via as `message.content`
+        idx=text_to_token_ids(message.content, tokenizer).to(device),  # The user text is provided via `message.content`
         max_new_tokens=50,
         context_size=model_config["context_length"],
         top_k=1,
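
The new inline comment records that `generate()` already runs its forward passes inside `with torch.no_grad()`, so the caller needs no extra no-grad context. As a rough illustration of that pattern, here is a bare-bones greedy decoder; this is a sketch only, not the book's actual `generate()`, which also supports top-k sampling and temperature scaling.

```python
import torch


@torch.no_grad()  # same effect as wrapping the body in `with torch.no_grad():`
def generate_sketch(model, idx, max_new_tokens, context_size):
    # Greedy decoding: repeatedly append the most likely next token.
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]       # crop to the context window
        logits = model(idx_cond)[:, -1, :]      # logits for the last position
        next_id = torch.argmax(logits, dim=-1, keepdim=True)
        idx = torch.cat([idx, next_id], dim=1)  # grow the sequence
    return idx
```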

ch05/06_user_interface/app_own.py

@@ -17,6 +17,8 @@ from previous_chapters import (
     token_ids_to_text,
 )
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 def get_model_and_tokenizer():
     """
@@ -34,8 +36,6 @@ def get_model_and_tokenizer():
         "qkv_bias": False  # Query-key-value bias
     }
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
     tokenizer = tiktoken.get_encoding("gpt2")
 
     model_path = Path("..") / "01_main-chapter-code" / "model.pth"
@@ -43,7 +43,7 @@ def get_model_and_tokenizer():
         print(f"Could not find the {model_path} file. Please run the chapter 5 code (ch05.ipynb) to generate the model.pth file.")
         sys.exit()
 
-    checkpoint = torch.load("model.pth", weights_only=True)
+    checkpoint = torch.load(model_path, weights_only=True)
     model = GPTModel(GPT_CONFIG_124M)
     model.load_state_dict(checkpoint)
     model.to(device)
@@ -60,9 +60,9 @@ async def main(message: chainlit.Message):
     """
    The main Chainlit function.
     """
-    token_ids = generate(
+    token_ids = generate(  # function uses `with torch.no_grad()` internally already
         model=model,
-        idx=text_to_token_ids(message.content, tokenizer),  # The user text is provided via as `message.content`
+        idx=text_to_token_ids(message.content, tokenizer).to(device),  # The user text is provided via `message.content`
         max_new_tokens=50,
         context_size=model_config["context_length"],
         top_k=1,
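
The `torch.load` hunk above implements the "fixed path" bullet: `torch.load("model.pth", ...)` resolves the filename against the current working directory and fails whenever the app is launched from elsewhere, whereas the `model_path` variable points at the checkpoint the preceding existence check already verified. A further hardening, sketched below, would anchor the path to the script file itself; note this goes beyond what the commit does and assumes the repo's folder layout.

```python
from pathlib import Path

import torch

# Resolve the checkpoint relative to this script file instead of the CWD,
# so `chainlit run` works from any directory (assumes the repo layout where
# the checkpoint sits next to the chapter 5 main code).
model_path = (Path(__file__).parent / ".." / "01_main-chapter-code" / "model.pth").resolve()

if model_path.exists():
    checkpoint = torch.load(model_path, weights_only=True)
```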