Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-10-30 01:10:33 +00:00)
fixed typos (#414)

* fixed typos
* fixed formatting
* Update ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb
* del weights after load into model

Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>

Parent: 8b60460319
Commit: 0ed1e0d099
@@ -83,8 +83,8 @@
 },
 "source": [
 "- To run all the code in this notebook, please ensure you update to at least PyTorch 2.5 (FlexAttention is not included in earlier PyTorch releases)\n",
-"If the code cell above shows a PyTorch version lower than 2.5, you can upgrade your PyTorch installation by uncommenting and running the following code cell (Please note that PyTorch 2.5 requires Python 3.9 or later)\n",
-"- For more specific instructions and CUDA versions, please refer to the official installation guide at https://pytorch.org."
+"- If the code cell above shows a PyTorch version lower than 2.5, you can upgrade your PyTorch installation by uncommenting and running the following code cell (Please note that PyTorch 2.5 requires Python 3.9 or later)\n",
+"- For more specific instructions and CUDA versions, please refer to the official installation guide at https://pytorch.org"
 ]
 },
 {
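The corrected bullets refer to an upgrade cell in the notebook; the sketch below is an illustrative version check and upgrade, not necessarily the notebook's exact contents, and assumes the `packaging` helper is installed.

```python
# Hedged sketch of a PyTorch version-check cell; FlexAttention requires PyTorch >= 2.5.
from packaging.version import parse as parse_version
import torch

if parse_version(torch.__version__) < parse_version("2.5.0"):
    print(f"Found PyTorch {torch.__version__}; please upgrade to >= 2.5 for FlexAttention")
    # Uncomment to upgrade from within a notebook cell:
    # !pip install --upgrade torch
```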
@@ -886,12 +886,14 @@
 "id": "d2164859-31a0-4537-b4fb-27d57675ba77"
 },
 "source": [
-"- Set `need_weights` (default `True`) to need_weights=False so that `MultiheadAttention` uses `scaled_dot_product_attention` [according to the documentation](https://github.com/pytorch/pytorch/blob/71d020262793542974cf13b30f2a9099773f015c/torch/nn/modules/activation.py#L1096)\n",
+"- Set `need_weights` (default `True`) to `False` so that `MultiheadAttention` uses `scaled_dot_product_attention` [according to the documentation](https://github.com/pytorch/pytorch/blob/71d020262793542974cf13b30f2a9099773f015c/torch/nn/modules/activation.py#L1096)\n",
 "\n",
-"> need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.\n",
-" Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``\n",
-" and achieve the best performance for MHA.\n",
-" Default: ``True``."
+"```markdown\n",
+"need_weights: If specified, returns `attn_output_weights` in addition to `attn_outputs`.\n",
+" Set `need_weights=False` to use the optimized `scaled_dot_product_attention`\n",
+" and achieve the best performance for MHA.\n",
+" Default: `True`\n",
+"```"
 ]
 },
 {
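To illustrate what the corrected bullet describes, here is a minimal sketch of calling PyTorch's `nn.MultiheadAttention` with `need_weights=False`; the tensor shapes and module sizes are arbitrary examples, not taken from the notebook.

```python
import torch
import torch.nn as nn

# With need_weights=False, MultiheadAttention can dispatch to the optimized
# scaled_dot_product_attention path and skips materializing the attention matrix.
mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
x = torch.randn(2, 8, 64)  # (batch, sequence length, embedding dim) - arbitrary sizes

attn_output, attn_weights = mha(x, x, x, need_weights=False)
print(attn_output.shape)  # torch.Size([2, 8, 64])
print(attn_weights)       # None, because the attention weights were not requested
```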
@@ -1965,7 +1967,7 @@
 "provenance": []
 },
 "kernelspec": {
-"display_name": "Python 3 (ipykernel)",
+"display_name": "pt",
 "language": "python",
 "name": "python3"
 },
@@ -1979,7 +1981,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.11.4"
+"version": "3.11.9"
 }
 },
 "nbformat": 4,
@@ -1843,7 +1843,7 @@
 "id": "VlH7qYVdDKQr"
 },
 "source": [
-"- Note that the Llama 3 model should ideally used with the correct prompt template that was used during finetuning (as discussed in chapter 7)\n",
+"- Note that the Llama 3 model should ideally be used with the correct prompt template that was used during finetuning (as discussed in chapter 7)\n",
 "- Below is a wrapper class around the tokenizer based on Meta AI's Llama 3-specific [ChatFormat code](https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/tokenizer.py#L202) that constructs the prompt template"
 ]
 },
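For reference, a hedged sketch of the prompt layout such a wrapper produces: the helper name and exact whitespace below are illustrative assumptions (the notebook wraps the tokenizer in a ChatFormat-style class instead), while the special tokens follow Meta AI's published Llama 3 chat format.

```python
# Hypothetical helper illustrating the Llama 3 chat prompt layout.
def format_llama3_user_prompt(message: str) -> str:
    return (
        "<|begin_of_text|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{message.strip()}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

print(format_llama3_user_prompt("What do llamas eat?"))
```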
@@ -2099,7 +2099,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"LLAMA32_CONFIG[\"context_length\"] = 8192"
+"LLAMA31_CONFIG_8B[\"context_length\"] = 8192"
 ]
 },
 {
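The fix points the override at the Llama 3.1 8B configuration dictionary. A hedged sketch of the pattern: the dictionary entries below are placeholder assumptions, and only the final override mirrors the notebook cell.

```python
# Illustrative config dictionary; values are assumptions for this sketch.
LLAMA31_CONFIG_8B = {
    "vocab_size": 128_256,      # assumed
    "context_length": 131_072,  # assumed original maximum context window
    # ... further architecture hyperparameters ...
}

# Reduce the supported context window before instantiating the model,
# which shrinks buffers that grow with context length (e.g., RoPE caches).
LLAMA31_CONFIG_8B["context_length"] = 8192
```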
@@ -2319,7 +2319,8 @@
 " combined_weights.update(current_weights)\n",
 "\n",
 "load_weights_into_llama(model, LLAMA31_CONFIG_8B, combined_weights)\n",
-"model.to(device);"
+"model.to(device);\n",
+"del combined_weights # free up memory"
 ]
 },
 {
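The added `del` line matches the commit note "del weights after load into model". The self-contained sketch below illustrates the idea with a toy module rather than the notebook's Llama model; the `gc.collect()` and `empty_cache()` calls are optional additions, not part of the notebook.

```python
import gc
import torch
import torch.nn as nn

# Once the checkpoint tensors are copied into the model, the loose dictionary
# only duplicates them in memory, so its reference can be dropped.
model = nn.Linear(4, 4)
checkpoint = {k: v.clone() for k, v in model.state_dict().items()}  # stand-in for loaded weights

model.load_state_dict(checkpoint)  # copies tensors into the model's parameters

del checkpoint   # free the duplicate copy held by the dictionary
gc.collect()     # optional: reclaim the memory immediately
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # optional when the weights were loaded onto the GPU
```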
@@ -2466,7 +2467,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"LLAMA32_CONFIG[\"context_length\"] = 8192"
+"LLAMA32_CONFIG_1B[\"context_length\"] = 8192"
 ]
 },
 {
@@ -2594,7 +2595,8 @@
 "current_weights = load_file(weights_file)\n",
 "\n",
 "load_weights_into_llama(model, LLAMA32_CONFIG_1B, current_weights)\n",
-"model.to(device);"
+"model.to(device);\n",
+"del current_weights # free up memory"
 ]
 },
 {