mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-12-13 16:01:34 +00:00
Set up basic test gh worklows (#79)
* Set up basic test gh worklows * update file paths * env check * add env check * Update requirements.txt * simplify * upd
This commit is contained in:
parent
9d6da22ebb
commit
ca96abac8a
41
.github/workflows/basic-tests.yml
vendored
Normal file
41
.github/workflows/basic-tests.yml
vendored
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
name: Test Python Scripts and Notebooks
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
paths:
|
||||||
|
- '**/*.py' # Run workflow for changes in Python files
|
||||||
|
- '**/*.ipynb' # Run workflow for changes in Jupyter notebooks
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
paths:
|
||||||
|
- '**/*.py'
|
||||||
|
- '**/*.ipynb'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: '3.10'
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install pytest nbval
|
||||||
|
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||||
|
|
||||||
|
- name: Test Selected Python Scripts
|
||||||
|
run: |
|
||||||
|
pytest ch04/01_main-chapter-code/tests.py
|
||||||
|
pytest appendix-A/02_installing-python-libraries/tests.py
|
||||||
|
|
||||||
|
- name: Validate Selected Jupyter Notebooks
|
||||||
|
run: |
|
||||||
|
pytest --nbval ch02/01_main-chapter-code/dataloader.ipynb
|
||||||
|
pytest --nbval ch03/01_main-chapter-code/multihead-attention.ipynb
|
||||||
@ -11,13 +11,13 @@
|
|||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"[OK] Your Python version is 3.10.12\n",
|
"[OK] Your Python version is 3.10.12\n",
|
||||||
"numpy >= 1.24.3 \n",
|
"[OK] numpy 1.26.0\n",
|
||||||
"matplotlib >= 3.7.1. \n",
|
"[OK] matplotlib 3.8.2\n",
|
||||||
"jupyterlab >= 4.0. \n",
|
"[OK] jupyterlab 4.0.6\n",
|
||||||
"tensorflow >= 2.15.0 \n",
|
"[OK] tensorflow 2.15.0\n",
|
||||||
"torch >= 2.0.1 \n",
|
"[OK] torch 2.2.1\n",
|
||||||
"tqdm >= 4.66.1 \n",
|
"[OK] tqdm 4.66.1\n",
|
||||||
"tiktoken >= 0.5.1 \n"
|
"[OK] tiktoken 0.5.1\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
# Sebastian Raschka, 2024
|
# Sebastian Raschka, 2024
|
||||||
|
|
||||||
|
import importlib
|
||||||
from os.path import dirname, join, realpath
|
from os.path import dirname, join, realpath
|
||||||
from packaging.version import parse as version_parse
|
from packaging.version import parse as version_parse
|
||||||
import platform
|
import platform
|
||||||
@ -16,28 +17,21 @@ def get_packages(pkgs):
|
|||||||
versions = []
|
versions = []
|
||||||
for p in pkgs:
|
for p in pkgs:
|
||||||
try:
|
try:
|
||||||
imported = __import__(p)
|
imported = importlib.import_module(p)
|
||||||
try:
|
try:
|
||||||
versions.append(imported.__version__)
|
version = (getattr(imported, '__version__', None) or
|
||||||
except AttributeError:
|
getattr(imported, 'version', None) or
|
||||||
try:
|
getattr(imported, 'version_info', None))
|
||||||
versions.append(imported.version)
|
if version is None:
|
||||||
except AttributeError:
|
# If common attributes don't exist, use importlib.metadata
|
||||||
try:
|
version = importlib.metadata.version(p)
|
||||||
versions.append(imported.version_info)
|
versions.append(version)
|
||||||
except AttributeError:
|
except importlib.metadata.PackageNotFoundError:
|
||||||
try:
|
# Handle case where package is not installed
|
||||||
import importlib
|
versions.append('0.0')
|
||||||
import importlib_metadata
|
|
||||||
imported = importlib.import_module(p)
|
|
||||||
version = importlib_metadata.version(p)
|
|
||||||
versions.append(version)
|
|
||||||
except ImportError:
|
|
||||||
version = "not installed"
|
|
||||||
versions.append('0.0')
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print(f'[FAIL]: {p} is not installed and/or cannot be imported.')
|
# Fallback if importlib.import_module fails for unexpected reasons
|
||||||
versions.append('N/A')
|
versions.append('0.0')
|
||||||
return versions
|
return versions
|
||||||
|
|
||||||
|
|
||||||
@ -50,11 +44,9 @@ def get_requirements_dict():
|
|||||||
for line in f:
|
for line in f:
|
||||||
if not line.strip():
|
if not line.strip():
|
||||||
continue
|
continue
|
||||||
line = line.split("#")[0]
|
line = line.split("#")[0].strip()
|
||||||
print(line)
|
|
||||||
line = line.split(" ")
|
line = line.split(" ")
|
||||||
if not line[0].strip() or not line[-1].strip():
|
line = [l.strip() for l in line]
|
||||||
continue
|
|
||||||
d[line[0]] = line[-1]
|
d[line[0]] = line[-1]
|
||||||
return d
|
return d
|
||||||
|
|
||||||
@ -72,6 +64,10 @@ def check_packages(d):
|
|||||||
print(f'[OK] {pkg_name} {actual_ver}')
|
print(f'[OK] {pkg_name} {actual_ver}')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
def main():
|
||||||
d = get_requirements_dict()
|
d = get_requirements_dict()
|
||||||
check_packages(d)
|
check_packages(d)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|||||||
7
appendix-A/02_installing-python-libraries/tests.py
Normal file
7
appendix-A/02_installing-python-libraries/tests.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from python_environment_check import main
|
||||||
|
|
||||||
|
|
||||||
|
def test_main(capsys):
|
||||||
|
main()
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "FAIL" not in captured.out
|
||||||
@ -18,31 +18,6 @@
|
|||||||
"This notebook contains the main takeaway, the data loading pipeline without the intermediate steps."
|
"This notebook contains the main takeaway, the data loading pipeline without the intermediate steps."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"id": "93804da5-372b-45ff-9ef4-8398ba1dd78e",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"torch version: 2.0.1\n",
|
|
||||||
"tiktoken version: 0.5.1\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"from importlib.metadata import version\n",
|
|
||||||
"\n",
|
|
||||||
"import tiktoken\n",
|
|
||||||
"import torch\n",
|
|
||||||
"\n",
|
|
||||||
"print(\"torch version:\", version(\"torch\"))\n",
|
|
||||||
"print(\"tiktoken version:\", version(\"tiktoken\"))"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 2,
|
||||||
@ -164,7 +139,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.12"
|
"version": "3.11.4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
@ -10,7 +10,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 1,
|
||||||
"id": "d70bae22-b540-4a13-ab01-e748cb9d55c9",
|
"id": "d70bae22-b540-4a13-ab01-e748cb9d55c9",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -146,8 +146,8 @@
|
|||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Fetching encoder.json: 1.04Mit [00:00, 3.03Mit/s] \n",
|
"Fetching encoder.json: 1.04Mit [00:00, 3.14Mit/s] \n",
|
||||||
"Fetching vocab.bpe: 457kit [00:00, 2.36Mit/s] \n"
|
"Fetching vocab.bpe: 457kit [00:00, 1.67Mit/s] \n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -215,25 +215,17 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 13,
|
"execution_count": 12,
|
||||||
"id": "e9077bf4-f91f-42ad-ab76-f3d89128510e",
|
"id": "e9077bf4-f91f-42ad-ab76-f3d89128510e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/Users/sebastian/miniforge3/envs/book/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
|
||||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"'4.33.2'"
|
"'4.34.0'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 13,
|
"execution_count": 12,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -246,10 +238,81 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": 13,
|
||||||
"id": "a9839137-b8ea-4a2c-85fc-9a63064cf8c8",
|
"id": "a9839137-b8ea-4a2c-85fc-9a63064cf8c8",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "e4df871bb797435787143a3abe6b0231",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Downloading tokenizer_config.json: 0%| | 0.00/26.0 [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "f11b27a4aabf43af9bf57f929683def6",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Downloading vocab.json: 0%| | 0.00/1.04M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "d3aa9a24aacc43108ef2ed72e7bacd33",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Downloading merges.txt: 0%| | 0.00/456k [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "f9341bc23b594bb68dcf8954bff6d9bd",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Downloading tokenizer.json: 0%| | 0.00/1.36M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "c5f55f2f1dbc4152acc9b2061167ee0a",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Downloading config.json: 0%| | 0.00/665 [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from transformers import GPT2Tokenizer\n",
|
"from transformers import GPT2Tokenizer\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -258,7 +321,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": 14,
|
||||||
"id": "222cbd69-6a3d-4868-9c1f-421ffc9d5fe1",
|
"id": "222cbd69-6a3d-4868-9c1f-421ffc9d5fe1",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -268,7 +331,7 @@
|
|||||||
"[15496, 11, 995, 13, 1148, 428, 438, 257, 1332, 30]"
|
"[15496, 11, 995, 13, 1148, 428, 438, 257, 1332, 30]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 15,
|
"execution_count": 14,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -287,7 +350,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 16,
|
"execution_count": 15,
|
||||||
"id": "a61bb445-b151-4a2f-8180-d4004c503754",
|
"id": "a61bb445-b151-4a2f-8180-d4004c503754",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -298,7 +361,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 17,
|
"execution_count": 16,
|
||||||
"id": "57f7c0a3-c1fd-4313-af34-68e78eb33653",
|
"id": "57f7c0a3-c1fd-4313-af34-68e78eb33653",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -306,7 +369,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"4.12 ms ± 41.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
"4.29 ms ± 46.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -316,7 +379,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 18,
|
"execution_count": 17,
|
||||||
"id": "036dd628-3591-46c9-a5ce-b20b105a8062",
|
"id": "036dd628-3591-46c9-a5ce-b20b105a8062",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -324,7 +387,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"1.75 ms ± 8.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
|
"1.4 ms ± 9.71 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -334,7 +397,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 19,
|
"execution_count": 18,
|
||||||
"id": "b9c85b58-bfbc-465e-9a7e-477e53d55c90",
|
"id": "b9c85b58-bfbc-465e-9a7e-477e53d55c90",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -349,7 +412,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"9.12 ms ± 856 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
"8.46 ms ± 48.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -359,7 +422,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 20,
|
"execution_count": 19,
|
||||||
"id": "7117107f-22a6-46b4-a442-712d50b3ac7a",
|
"id": "7117107f-22a6-46b4-a442-712d50b3ac7a",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -367,7 +430,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"8.63 ms ± 247 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
"8.36 ms ± 184 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -392,7 +455,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.4"
|
"version": "3.10.12"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
1
ch02/02_bonus_bytepair-encoder/gpt2_model/encoder.json
Normal file
1
ch02/02_bonus_bytepair-encoder/gpt2_model/encoder.json
Normal file
File diff suppressed because one or more lines are too long
50001
ch02/02_bonus_bytepair-encoder/gpt2_model/vocab.bpe
Normal file
50001
ch02/02_bonus_bytepair-encoder/gpt2_model/vocab.bpe
Normal file
File diff suppressed because it is too large
Load Diff
@ -26,13 +26,12 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"torch version: 2.1.0\n"
|
"torch version: 2.2.1\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from importlib.metadata import version\n",
|
"from importlib.metadata import version\n",
|
||||||
"import torch\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"print(\"torch version:\", version(\"torch\"))"
|
"print(\"torch version:\", version(\"torch\"))"
|
||||||
]
|
]
|
||||||
@ -1992,7 +1991,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.6"
|
"version": "3.10.12"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
@ -8,6 +8,26 @@
|
|||||||
"# Multi-head Attention Plus Data Loading"
|
"# Multi-head Attention Plus Data Loading"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "ac9b5847-0515-45cd-87b0-46541f6a1f79",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"torch version: 2.2.1\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from importlib.metadata import version\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"torch version:\", version(\"torch\"))"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "070000fc-a7b7-4c56-a2c0-a938d413a790",
|
"id": "070000fc-a7b7-4c56-a2c0-a938d413a790",
|
||||||
@ -28,7 +48,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 2,
|
||||||
"id": "0ed4b7db-3b47-4fd3-a4a6-5f4ed5dd166e",
|
"id": "0ed4b7db-3b47-4fd3-a4a6-5f4ed5dd166e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -96,7 +116,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 3,
|
||||||
"id": "664397bc-6daa-4b88-90aa-e8fc1fbd5846",
|
"id": "664397bc-6daa-4b88-90aa-e8fc1fbd5846",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -114,7 +134,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 4,
|
||||||
"id": "d3664332-e6bb-447e-8b96-203aafde8b24",
|
"id": "d3664332-e6bb-447e-8b96-203aafde8b24",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -148,7 +168,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 5,
|
||||||
"id": "a44e682d-1c3c-445d-85fa-b142f89f8503",
|
"id": "a44e682d-1c3c-445d-85fa-b142f89f8503",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -196,7 +216,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 6,
|
||||||
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
|
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -235,7 +255,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 10,
|
"execution_count": 7,
|
||||||
"id": "2773c09d-c136-4372-a2be-04b58d292842",
|
"id": "2773c09d-c136-4372-a2be-04b58d292842",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -298,7 +318,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 11,
|
"execution_count": 8,
|
||||||
"id": "779fdd04-0152-4308-af08-840800a7f395",
|
"id": "779fdd04-0152-4308-af08-840800a7f395",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -324,14 +344,6 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"print(\"context_vecs.shape:\", context_vecs.shape)"
|
"print(\"context_vecs.shape:\", context_vecs.shape)"
|
||||||
]
|
]
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "3ac01b16-8ac6-4487-a6f2-fd9cf33a9fe4",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@ -350,7 +362,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.6"
|
"version": "3.10.12"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
@ -235,8 +235,7 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
|
|||||||
return idx
|
return idx
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def main():
|
||||||
|
|
||||||
GPT_CONFIG_124M = {
|
GPT_CONFIG_124M = {
|
||||||
"vocab_size": 50257, # Vocabulary size
|
"vocab_size": 50257, # Vocabulary size
|
||||||
"ctx_len": 1024, # Context length
|
"ctx_len": 1024, # Context length
|
||||||
@ -274,3 +273,7 @@ if __name__ == "__main__":
|
|||||||
print("\nOutput:", out)
|
print("\nOutput:", out)
|
||||||
print("Output length:", len(out[0]))
|
print("Output length:", len(out[0]))
|
||||||
print("Output text:", decoded_text)
|
print("Output text:", decoded_text)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|||||||
33
ch04/01_main-chapter-code/tests.py
Normal file
33
ch04/01_main-chapter-code/tests.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from gpt import main
|
||||||
|
|
||||||
|
expected = """
|
||||||
|
==================================================
|
||||||
|
IN
|
||||||
|
==================================================
|
||||||
|
|
||||||
|
Input text: Hello, I am
|
||||||
|
Encoded input text: [15496, 11, 314, 716]
|
||||||
|
encoded_tensor.shape: torch.Size([1, 4])
|
||||||
|
|
||||||
|
|
||||||
|
==================================================
|
||||||
|
OUT
|
||||||
|
==================================================
|
||||||
|
|
||||||
|
Output: tensor([[15496, 11, 314, 716, 27018, 24086, 47843, 30961, 42348, 7267,
|
||||||
|
49706, 43231, 47062, 34657]])
|
||||||
|
Output length: 14
|
||||||
|
Output text: Hello, I am Featureiman Byeswickattribute argue logger Normandy Compton analogous
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_main(capsys):
|
||||||
|
main()
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
|
||||||
|
# Normalize line endings and strip trailing whitespace from each line
|
||||||
|
normalized_expected = '\n'.join(line.rstrip() for line in expected.splitlines())
|
||||||
|
normalized_output = '\n'.join(line.rstrip() for line in captured.out.splitlines())
|
||||||
|
|
||||||
|
# Compare normalized strings
|
||||||
|
assert normalized_output == normalized_expected
|
||||||
@ -1,6 +1,6 @@
|
|||||||
numpy >= 1.24.3 # ch05
|
numpy >= 1.24.3 # ch05
|
||||||
matplotlib >= 3.7.1. # ch04, ch05
|
matplotlib >= 3.7.1 # ch04, ch05
|
||||||
jupyterlab >= 4.0. # all
|
jupyterlab >= 4.0 # all
|
||||||
tensorflow >= 2.15.0 # ch05
|
tensorflow >= 2.15.0 # ch05
|
||||||
torch >= 2.0.1 # all
|
torch >= 2.0.1 # all
|
||||||
tqdm >= 4.66.1 # ch05
|
tqdm >= 4.66.1 # ch05
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user