mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-31 20:08:08 +00:00
Llama3 from scratch improvements (#621)
* Llama3 from scratch improvements * restore
This commit is contained in:
parent
1cbdcd86c3
commit
3eca919a52
@ -17,13 +17,16 @@ This folder contains code for converting the GPT implementation from chapter 4 a
|
|||||||
For an easy way to use the Llama 3.2 1B and 3B models, you can also use the `llms-from-scratch` PyPI package based on the source code in this repository at [pkg/llms_from_scratch](../../pkg/llms_from_scratch).
|
For an easy way to use the Llama 3.2 1B and 3B models, you can also use the `llms-from-scratch` PyPI package based on the source code in this repository at [pkg/llms_from_scratch](../../pkg/llms_from_scratch).
|
||||||
|
|
||||||
|
|
||||||
##### 1) Installation
|
#### 1) Installation
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install llms_from_scratch blobfile
|
pip install llms_from_scratch blobfile
|
||||||
```
|
```
|
||||||
|
|
||||||
|
(Note that `blobfile` is needed to load the tokenizer.)
|
||||||
|
|
||||||
|
|
||||||
##### 2) Model and text generation settings
|
#### 2) Model and text generation settings
|
||||||
|
|
||||||
Specify which model to use:
|
Specify which model to use:
|
||||||
|
|
||||||
@ -51,7 +54,7 @@ TOP_K = 1
|
|||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
##### 3) Weight download and loading
|
#### 3) Weight download and loading
|
||||||
|
|
||||||
This automatically downloads the weight file based on the model choice above:
|
This automatically downloads the weight file based on the model choice above:
|
||||||
|
|
||||||
@ -82,7 +85,7 @@ else:
|
|||||||
LLAMA32_CONFIG["context_length"] = MODEL_CONTEXT_LENGTH
|
LLAMA32_CONFIG["context_length"] = MODEL_CONTEXT_LENGTH
|
||||||
|
|
||||||
model = Llama3Model(LLAMA32_CONFIG)
|
model = Llama3Model(LLAMA32_CONFIG)
|
||||||
model.load_state_dict(torch.load(MODEL_FILE, weights_only=True))
|
model.load_state_dict(torch.load(MODEL_FILE, weights_only=True, map_location="cpu"))
|
||||||
|
|
||||||
device = (
|
device = (
|
||||||
torch.device("cuda") if torch.cuda.is_available() else
|
torch.device("cuda") if torch.cuda.is_available() else
|
||||||
@ -93,7 +96,7 @@ model.to(device)
|
|||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
##### 4) Initialize tokenizer
|
#### 4) Initialize tokenizer
|
||||||
|
|
||||||
The following code downloads and initializes the tokenizer:
|
The following code downloads and initializes the tokenizer:
|
||||||
|
|
||||||
@ -115,14 +118,14 @@ if "instruct" in MODEL_FILE:
|
|||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
##### 5) Generating text
|
#### 5) Generating text
|
||||||
|
|
||||||
Lastly, we can generate text via the following code:
|
Lastly, we can generate text via the following code:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from llms_from_scratch.ch05 import (
|
from ch05 import (
|
||||||
generate,
|
generate,
|
||||||
text_to_token_ids,
|
text_to_token_ids,
|
||||||
token_ids_to_text
|
token_ids_to_text
|
||||||
@ -141,7 +144,9 @@ token_ids = generate(
|
|||||||
temperature=TEMPERATURE
|
temperature=TEMPERATURE
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Time: {time.time() - start:.2f} sec")
|
total_time = time.time() - start
|
||||||
|
print(f"Time: {total_time:.2f} sec")
|
||||||
|
print(f"{int(len(token_ids[0])/total_time)} tokens/sec")
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
max_mem_bytes = torch.cuda.max_memory_allocated()
|
max_mem_bytes = torch.cuda.max_memory_allocated()
|
||||||
@ -159,7 +164,8 @@ print("\n\nOutput text:\n\n", output_text)
|
|||||||
When using the Llama 3.2 1B Instruct model, the output should look similar to the one shown below:
|
When using the Llama 3.2 1B Instruct model, the output should look similar to the one shown below:
|
||||||
|
|
||||||
```
|
```
|
||||||
Time: 4.12 sec
|
Time: 3.17 sec
|
||||||
|
50 tokens/sec
|
||||||
Max memory allocated: 2.91 GB
|
Max memory allocated: 2.91 GB
|
||||||
|
|
||||||
|
|
||||||
@ -176,7 +182,22 @@ It's worth noting that the specific diet of llamas can vary depending on factors
|
|||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
**Pro tip**
|
#### Pro tip 1: speed up inference with FlashAttention
|
||||||
|
|
||||||
|
Instead of using `Llama3Model`, you can use `Llama3ModelFast` as a drop-in replacement. For more information, I encourage you to inspect the [pkg/llms_from_scratch/llama3.py](../../pkg/llms_from_scratch/llama3.py) code.
|
||||||
|
|
||||||
|
The `Llama3ModelFast` replaces my from-scratch scaled dot-product code in the `GroupedQueryAttention` module with PyTorch's `scaled_dot_product` function, which uses `FlashAttention` on Ampere GPUs or newer.
|
||||||
|
|
||||||
|
The following table shows a performance comparison on an A100:
|
||||||
|
|
||||||
|
| | Tokens/sec | Memory |
|
||||||
|
| --------------- | ---------- | ------- |
|
||||||
|
| Llama3Model | 50 | 2.91 GB |
|
||||||
|
| Llama3ModelFast | 58 | 2.85 GB |
|
||||||
|
|
||||||
|
|
||||||
|
#### Pro tip 2: speed up inference with compilation
|
||||||
|
|
||||||
|
|
||||||
For up to a 4× speed-up, replace
|
For up to a 4× speed-up, replace
|
||||||
|
|
||||||
@ -191,5 +212,11 @@ model = torch.compile(model)
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
```
|
```
|
||||||
|
|
||||||
Note: the speed-up takes effect after the first `generate` call.
|
Note: There is a significant multi-minute upfront cost when compiling, and the speed-up takes effect after the first `generate` call.
|
||||||
|
|
||||||
|
The following table shows a performance comparison on an A100 for consequent `generate` calls:
|
||||||
|
|
||||||
|
| | Tokens/sec | Memory |
|
||||||
|
| --------------- | ---------- | ------- |
|
||||||
|
| Llama3Model | 156 | 3.12 GB |
|
||||||
|
| Llama3ModelFast | 159 | 2.84 GB |
|
||||||
|
Loading…
x
Reference in New Issue
Block a user