Chainlit bonus material fixes (#361)

* fix cmd

* moved idx to device

* improved code with clone().detach() (see the sketch below)

* fixed path

* fix: added extra line for pep8

* updated .gitignore

* Update ch05/06_user_interface/app_orig.py

* Update ch05/06_user_interface/app_own.py

* Apply suggestions from code review

---------

Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>
Daniel Kleine 2024-09-18 17:08:50 +02:00 committed by GitHub
parent 1bc560fb13
commit 92ad9570e4
4 changed files with 11 additions and 10 deletions
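
Note on the "improved code with clone().detach()" bullet above: it refers to the PyTorch-recommended idiom for copying an existing tensor. A minimal, self-contained sketch of the idiom (the variable name `encoded` is hypothetical; the exact call site is not part of this diff):

```python
import torch

# Hypothetical token-ID tensor; the actual call site is not shown in this diff.
encoded = torch.tensor([[6109, 3626, 6100, 345]])

# Recommended: copy the data and explicitly detach it from the autograd graph.
idx = encoded.clone().detach()

# Discouraged: constructing a tensor from an existing tensor emits a
# UserWarning in PyTorch that recommends clone().detach() instead.
# idx = torch.tensor(encoded)
```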

.gitignore

@@ -92,6 +92,7 @@ ch07/04_preference-tuning-with-dpo/loss-plot.pdf
 # Other
 ch05/06_user_interface/chainlit.md
 ch05/06_user_interface/.chainlit
+ch05/06_user_interface/.files

 # Temporary OS-related files
 .DS_Store

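Presumably, the new `.files` entry covers the local directory that Chainlit creates at runtime for file elements, matching the generated `chainlit.md` and `.chainlit` artifacts already ignored above.
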
ch05/06_user_interface/README.md

@@ -17,7 +17,7 @@ To implement this user interface, we use the open-source [Chainlit Python packag
 First, we install the `chainlit` package via
-```python
+```bash
 pip install chainlit
 ```

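This hunk only corrects the fence label: `pip install chainlit` is a shell command, not Python code. Once installed, the app is likewise started from the terminal (for example via `chainlit run app.py`, the standard Chainlit CLI invocation).
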
ch05/06_user_interface/app_orig.py

@@ -16,6 +16,8 @@ from previous_chapters import (
     token_ids_to_text,
 )

+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+

 def get_model_and_tokenizer():
     """
@@ -44,8 +46,6 @@ def get_model_and_tokenizer():
     BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
     settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

     gpt = GPTModel(BASE_CONFIG)
@@ -67,9 +67,9 @@ async def main(message: chainlit.Message):
     """
     The main Chainlit function.
     """
-    token_ids = generate(
+    token_ids = generate(  # function uses `with torch.no_grad()` internally already
         model=model,
-        idx=text_to_token_ids(message.content, tokenizer),  # The user text is provided via `message.content`
+        idx=text_to_token_ids(message.content, tokenizer).to(device),  # The user text is provided via `message.content`
         max_new_tokens=50,
         context_size=model_config["context_length"],
         top_k=1,

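The `.to(device)` change above avoids a device mismatch when the model weights live on the GPU while the freshly tokenized user input is created on the CPU. A minimal sketch of the failure mode, using a plain `torch.nn.Linear` as a stand-in for the book's `GPTModel`:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Linear(4, 2).to(device)    # stand-in for the GPT model
idx = torch.tensor([[1.0, 2.0, 3.0, 4.0]])  # input tensors start on the CPU

# On a CUDA machine, model(idx) would raise a RuntimeError because the input
# and the weights sit on different devices; moving the input first is safe
# on both CPU-only and GPU machines.
out = model(idx.to(device))
print(out.shape)  # torch.Size([1, 2])
```
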
ch05/06_user_interface/app_own.py

@@ -17,6 +17,8 @@ from previous_chapters import (
     token_ids_to_text,
 )

+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+

 def get_model_and_tokenizer():
     """
@@ -34,8 +36,6 @@ def get_model_and_tokenizer():
         "qkv_bias": False  # Query-key-value bias
     }

-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
     tokenizer = tiktoken.get_encoding("gpt2")

     model_path = Path("..") / "01_main-chapter-code" / "model.pth"
@@ -43,7 +43,7 @@ def get_model_and_tokenizer():
         print(f"Could not find the {model_path} file. Please run the chapter 5 code (ch05.ipynb) to generate the model.pth file.")
         sys.exit()

-    checkpoint = torch.load("model.pth", weights_only=True)
+    checkpoint = torch.load(model_path, weights_only=True)
     model = GPTModel(GPT_CONFIG_124M)
     model.load_state_dict(checkpoint)
     model.to(device)
@@ -60,9 +60,9 @@ async def main(message: chainlit.Message):
     """
     The main Chainlit function.
     """
-    token_ids = generate(
+    token_ids = generate(  # function uses `with torch.no_grad()` internally already
         model=model,
-        idx=text_to_token_ids(message.content, tokenizer),  # The user text is provided via `message.content`
+        idx=text_to_token_ids(message.content, tokenizer).to(device),  # The user text is provided via `message.content`
         max_new_tokens=50,
         context_size=model_config["context_length"],
         top_k=1,
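
The `torch.load` change above makes the checkpoint load independent of the current working directory: `model_path` is resolved relative to the repository layout instead of being hard-coded as `"model.pth"`. A minimal sketch of the pattern, with a dummy linear layer standing in for `GPTModel` and a locally created demo checkpoint:

```python
import torch
from pathlib import Path

# Stand-in for GPTModel; the real checkpoint comes from the chapter 5 code.
model = torch.nn.Linear(4, 2)

# Create a demo checkpoint so the example is self-contained.
demo_path = Path("demo_model.pth")
torch.save(model.state_dict(), demo_path)

# Load via the resolved Path object rather than a string that is only valid
# when the script runs from its own directory. weights_only=True restricts
# unpickling to tensors and primitive types, the safer loading mode.
checkpoint = torch.load(demo_path, weights_only=True)
model.load_state_dict(checkpoint)
```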