diff --git a/ch05/01_main-chapter-code/tests.py b/ch05/01_main-chapter-code/tests.py index 4ffcb1e..97129f1 100644 --- a/ch05/01_main-chapter-code/tests.py +++ b/ch05/01_main-chapter-code/tests.py @@ -7,6 +7,8 @@ import pytest from gpt_train import main +import http.client +from urllib.parse import urlparse @pytest.fixture @@ -38,3 +40,59 @@ def test_main(gpt_config, other_settings): assert len(train_losses) == 39, "Unexpected number of training losses" assert len(val_losses) == 39, "Unexpected number of validation losses" assert len(tokens_seen) == 39, "Unexpected number of tokens seen" + + +def check_file_size(url, expected_size): + parsed_url = urlparse(url) + if parsed_url.scheme == "https": + conn = http.client.HTTPSConnection(parsed_url.netloc) + else: + conn = http.client.HTTPConnection(parsed_url.netloc) + + conn.request("HEAD", parsed_url.path) + response = conn.getresponse() + if response.status != 200: + return False, f"{url} not accessible" + size = response.getheader("Content-Length") + if size is None: + return False, "Content-Length header is missing" + size = int(size) + if size != expected_size: + return False, f"{url} file has expected size {expected_size}, but got {size}" + return True, f"{url} file size is correct" + + +def test_model_files(): + base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models" + + model_size = "124M" + files = { + "checkpoint": 77, + "encoder.json": 1042301, + "hparams.json": 90, + "model.ckpt.data-00000-of-00001": 497759232, + "model.ckpt.index": 5215, + "model.ckpt.meta": 471155, + "vocab.bpe": 456318 + } + + for file_name, expected_size in files.items(): + url = f"{base_url}/{model_size}/{file_name}" + valid, message = check_file_size(url, expected_size) + assert valid, message + + model_size = "355M" + files = { + "checkpoint": 77, + "encoder.json": 1042301, + "hparams.json": 91, + "model.ckpt.data-00000-of-00001": 1419292672, + "model.ckpt.index": 10399, + "model.ckpt.meta": 926519, + "vocab.bpe": 456318 + } + + for file_name, expected_size in files.items(): + url = f"{base_url}/{model_size}/{file_name}" + valid, message = check_file_size(url, expected_size) + assert valid, message