Mirror of https://github.com/rasbt/LLMs-from-scratch.git, synced 2025-12-11 15:02:29 +00:00
Chainlit bonus material fixes (#361)
* fix cmd
* moved idx to device
* improved code with clone().detach()
* fixed path
* fix: added extra line for pep8
* updated .gitignore
* Update ch05/06_user_interface/app_orig.py
* Update ch05/06_user_interface/app_own.py
* Apply suggestions from code review

Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>
This commit is contained in:
Parent: ea9b4e83a4
Commit: eefe4bf12b
.gitignore (vendored): 1 addition

````diff
@@ -92,6 +92,7 @@ ch07/04_preference-tuning-with-dpo/loss-plot.pdf
 # Other
 ch05/06_user_interface/chainlit.md
 ch05/06_user_interface/.chainlit
+ch05/06_user_interface/.files
 
 # Temporary OS-related files
 .DS_Store
````
ch05/06_user_interface/README.md

````diff
@@ -17,7 +17,7 @@ To implement this user interface, we use the open-source [Chainlit Python packag
 
 First, we install the `chainlit` package via
 
-```python
+```bash
 pip install chainlit
 ```
 
````
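For orientation, both app files register an async handler with Chainlit; the sketch below shows only that skeleton, assuming Chainlit's standard `@chainlit.on_message` decorator and `Message.send()` API. The echo reply is a placeholder invented for illustration; the real apps call the book's `generate` function instead.

```python
# Minimal sketch of the Chainlit app structure used by app_orig.py / app_own.py.
# The echo reply is a placeholder; the real apps run GPT inference here.
import chainlit


@chainlit.on_message
async def main(message: chainlit.Message):
    """Runs once per user message; the text arrives as `message.content`."""
    reply = f"You said: {message.content}"  # placeholder for model inference
    await chainlit.Message(content=reply).send()
```

The UI is then started from the `ch05/06_user_interface` folder with `chainlit run app_orig.py` (or `app_own.py`).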
ch05/06_user_interface/app_orig.py

````diff
@@ -16,6 +16,8 @@ from previous_chapters import (
     token_ids_to_text,
 )
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 def get_model_and_tokenizer():
     """
@@ -44,8 +46,6 @@ def get_model_and_tokenizer():
 
     BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
     settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
 
     gpt = GPTModel(BASE_CONFIG)
@@ -67,9 +67,9 @@ async def main(message: chainlit.Message):
     """
     The main Chainlit function.
     """
-    token_ids = generate(
+    token_ids = generate(  # function uses `with torch.no_grad()` internally already
         model=model,
-        idx=text_to_token_ids(message.content, tokenizer),  # The user text is provided via as `message.content`
+        idx=text_to_token_ids(message.content, tokenizer).to(device),  # The user text is provided via as `message.content`
         max_new_tokens=50,
         context_size=model_config["context_length"],
         top_k=1,
````
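The recurring fix in both app files is device placement: `device` is hoisted to module scope so the message handler can move each freshly tokenized input onto the same device as the model; without `.to(device)`, a CUDA-resident model would raise a device-mismatch error on CPU inputs. A minimal, self-contained sketch of that pattern follows; the `torch.nn.Linear` stand-in, tensor shapes, and `handle_request` name are invented for illustration.

```python
# Sketch of the device-placement pattern from the diff above: define
# `device` once at module scope, move the model to it at load time,
# and move each new input tensor to it before inference.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Linear(8, 2)  # stand-in for GPTModel
model.to(device)
model.eval()


def handle_request(x: torch.Tensor) -> torch.Tensor:
    # Inputs typically arrive on CPU (e.g. from a tokenizer); .to(device)
    # keeps them on the same device as the model's weights.
    with torch.no_grad():
        return model(x.to(device))


print(handle_request(torch.randn(1, 8)))
```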
ch05/06_user_interface/app_own.py

````diff
@@ -17,6 +17,8 @@ from previous_chapters import (
     token_ids_to_text,
 )
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 def get_model_and_tokenizer():
     """
@@ -34,8 +36,6 @@ def get_model_and_tokenizer():
         "qkv_bias": False  # Query-key-value bias
     }
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
     tokenizer = tiktoken.get_encoding("gpt2")
 
     model_path = Path("..") / "01_main-chapter-code" / "model.pth"
@@ -43,7 +43,7 @@ def get_model_and_tokenizer():
         print(f"Could not find the {model_path} file. Please run the chapter 5 code (ch05.ipynb) to generate the model.pth file.")
         sys.exit()
 
-    checkpoint = torch.load("model.pth", weights_only=True)
+    checkpoint = torch.load(model_path, weights_only=True)
     model = GPTModel(GPT_CONFIG_124M)
     model.load_state_dict(checkpoint)
     model.to(device)
@@ -60,9 +60,9 @@ async def main(message: chainlit.Message):
     """
     The main Chainlit function.
     """
-    token_ids = generate(
+    token_ids = generate(  # function uses `with torch.no_grad()` internally already
         model=model,
-        idx=text_to_token_ids(message.content, tokenizer),  # The user text is provided via as `message.content`
+        idx=text_to_token_ids(message.content, tokenizer).to(device),  # The user text is provided via as `message.content`
        max_new_tokens=50,
         context_size=model_config["context_length"],
         top_k=1,
````
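The "fixed path" bullet corresponds to the `torch.load` change above: the hard-coded string `"model.pth"` resolved relative to the current working directory, so loading failed whenever the app was launched from elsewhere, while `model_path` is the exact `Path` whose existence was just checked. A sketch of the corrected loading flow, with the model construction abbreviated to comments since it depends on the book's `GPTModel` class:

```python
# Sketch of the corrected checkpoint loading: reuse the same Path that
# was verified with .exists() instead of a hard-coded relative filename.
import sys
from pathlib import Path

import torch

model_path = Path("..") / "01_main-chapter-code" / "model.pth"
if not model_path.exists():
    print(f"Could not find {model_path}. Run the chapter 5 code first.")
    sys.exit()

# weights_only=True restricts loading to tensors and plain containers,
# refusing arbitrary pickled objects in the checkpoint file.
checkpoint = torch.load(model_path, weights_only=True)
# model = GPTModel(GPT_CONFIG_124M)   # as in app_own.py
# model.load_state_dict(checkpoint)
```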