break long lines in ch07 (#263)

This commit is contained in:
Sebastian Raschka 2024-07-13 05:26:23 -07:00 committed by GitHub
parent 3549c2b0d4
commit 377c14c25e

View File

@ -48,11 +48,11 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"matplotlib version: 3.7.1\n", "matplotlib version: 3.8.4\n",
"tiktoken version: 0.7.0\n", "tiktoken version: 0.6.0\n",
"torch version: 2.3.0+cu121\n", "torch version: 2.2.2\n",
"tqdm version: 4.66.4\n", "tqdm version: 4.66.2\n",
"tensorflow version: 2.15.0\n" "tensorflow version: 2.16.1\n"
] ]
} }
], ],
@ -121,7 +121,7 @@
"source": [ "source": [
"- The topics covered in this chapter are summarized in the figure below\n", "- The topics covered in this chapter are summarized in the figure below\n",
"\n", "\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch07_compressed/chapter-overview-1.webp\" width=500px>" "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch07_compressed/chapter-overview-1.webp?123\" width=500px>"
] ]
}, },
{ {
@ -146,7 +146,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 3,
"id": "0G3axLw6kY1N", "id": "0G3axLw6kY1N",
"metadata": { "metadata": {
"colab": { "colab": {
@ -188,7 +188,10 @@
"\n", "\n",
"\n", "\n",
"file_path = \"instruction-data.json\"\n", "file_path = \"instruction-data.json\"\n",
"url = \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch07/01_main-chapter-code/instruction-data.json\"\n", "url = (\n",
" \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch\"\n",
" \"/main/ch07/01_main-chapter-code/instruction-data.json\"\n",
")\n",
"\n", "\n",
"data = download_and_load_file(file_path, url)\n", "data = download_and_load_file(file_path, url)\n",
"print(\"Number of entries:\", len(data))" "print(\"Number of entries:\", len(data))"
@ -206,7 +209,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 4,
"id": "-LiuBMsHkzQV", "id": "-LiuBMsHkzQV",
"metadata": { "metadata": {
"colab": { "colab": {
@ -241,7 +244,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 5,
"id": "uFInFxDDk2Je", "id": "uFInFxDDk2Je",
"metadata": { "metadata": {
"colab": { "colab": {
@ -298,7 +301,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 6,
"id": "Jhk37nnJnkBh", "id": "Jhk37nnJnkBh",
"metadata": { "metadata": {
"id": "Jhk37nnJnkBh" "id": "Jhk37nnJnkBh"
@ -329,7 +332,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 7,
"id": "F9UQRfjzo4Js", "id": "F9UQRfjzo4Js",
"metadata": { "metadata": {
"colab": { "colab": {
@ -375,7 +378,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 8,
"id": "a3891fa9-f738-41cd-946c-80ef9a99c346", "id": "a3891fa9-f738-41cd-946c-80ef9a99c346",
"metadata": { "metadata": {
"colab": { "colab": {
@ -418,7 +421,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 9,
"id": "aFZVopbIlNfx", "id": "aFZVopbIlNfx",
"metadata": { "metadata": {
"id": "aFZVopbIlNfx" "id": "aFZVopbIlNfx"
@ -436,7 +439,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 10,
"id": "-zf6oht6bIUQ", "id": "-zf6oht6bIUQ",
"metadata": { "metadata": {
"colab": { "colab": {
@ -508,7 +511,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 11,
"id": "adc29dc4-f1c7-4c71-937b-95119d6239bb", "id": "adc29dc4-f1c7-4c71-937b-95119d6239bb",
"metadata": { "metadata": {
"id": "adc29dc4-f1c7-4c71-937b-95119d6239bb" "id": "adc29dc4-f1c7-4c71-937b-95119d6239bb"
@ -553,7 +556,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 12,
"id": "ff24fe1a-5746-461c-ad3d-b6d84a1a7c96", "id": "ff24fe1a-5746-461c-ad3d-b6d84a1a7c96",
"metadata": { "metadata": {
"colab": { "colab": {
@ -602,7 +605,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 15,
"id": "eb4c77dd-c956-4a1b-897b-b466909f18ca", "id": "eb4c77dd-c956-4a1b-897b-b466909f18ca",
"metadata": { "metadata": {
"id": "eb4c77dd-c956-4a1b-897b-b466909f18ca" "id": "eb4c77dd-c956-4a1b-897b-b466909f18ca"
@ -626,7 +629,10 @@
" new_item += [pad_token_id]\n", " new_item += [pad_token_id]\n",
" # Pad sequences to max_length\n", " # Pad sequences to max_length\n",
" # this always adds at least 1 additional padding tokens\n", " # this always adds at least 1 additional padding tokens\n",
" padded = new_item + [pad_token_id] * (batch_max_length - len(new_item))\n", " padded = (\n",
" new_item + [pad_token_id] * \n",
" (batch_max_length - len(new_item))\n",
" )\n",
" # We remove this extra padded token again here\n", " # We remove this extra padded token again here\n",
" inputs = torch.tensor(padded[:-1])\n", " inputs = torch.tensor(padded[:-1])\n",
" inputs_lst.append(inputs)\n", " inputs_lst.append(inputs)\n",
@ -638,7 +644,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 14,
"id": "8fb02373-59b3-4f3a-b1d1-8181a2432645", "id": "8fb02373-59b3-4f3a-b1d1-8181a2432645",
"metadata": { "metadata": {
"colab": { "colab": {
@ -728,7 +734,10 @@
" # Add an <|endoftext|> token\n", " # Add an <|endoftext|> token\n",
" new_item += [pad_token_id]\n", " new_item += [pad_token_id]\n",
" # Pad sequences to max_length\n", " # Pad sequences to max_length\n",
" padded = new_item + [pad_token_id] * (batch_max_length - len(new_item))\n", " padded = (\n",
" new_item + [pad_token_id] * \n",
" (batch_max_length - len(new_item))\n",
" )\n",
" inputs = torch.tensor(padded[:-1]) # Truncate the last token for inputs\n", " inputs = torch.tensor(padded[:-1]) # Truncate the last token for inputs\n",
" targets = torch.tensor(padded[1:]) # Shift +1 to the right for targets\n", " targets = torch.tensor(padded[1:]) # Shift +1 to the right for targets\n",
" inputs_lst.append(inputs)\n", " inputs_lst.append(inputs)\n",
@ -832,7 +841,10 @@
" # Add an <|endoftext|> token\n", " # Add an <|endoftext|> token\n",
" new_item += [pad_token_id]\n", " new_item += [pad_token_id]\n",
" # Pad sequences to max_length\n", " # Pad sequences to max_length\n",
" padded = new_item + [pad_token_id] * (batch_max_length - len(new_item))\n", " padded = (\n",
" new_item + [pad_token_id] * \n",
" (batch_max_length - len(new_item))\n",
" )\n",
" inputs = torch.tensor(padded[:-1]) # Truncate the last token for inputs\n", " inputs = torch.tensor(padded[:-1]) # Truncate the last token for inputs\n",
" targets = torch.tensor(padded[1:]) # Shift +1 to the right for targets\n", " targets = torch.tensor(padded[1:]) # Shift +1 to the right for targets\n",
"\n", "\n",
@ -1132,7 +1144,11 @@
"source": [ "source": [
"from functools import partial\n", "from functools import partial\n",
"\n", "\n",
"customized_collate_fn = partial(custom_collate_fn, device=device, allowed_max_length=1024)" "customized_collate_fn = partial(\n",
" custom_collate_fn,\n",
" device=device,\n",
" allowed_max_length=1024\n",
")"
] ]
}, },
{ {
@ -1535,7 +1551,10 @@
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])\n", "BASE_CONFIG.update(model_configs[CHOOSE_MODEL])\n",
"\n", "\n",
"model_size = CHOOSE_MODEL.split(\" \")[-1].lstrip(\"(\").rstrip(\")\")\n", "model_size = CHOOSE_MODEL.split(\" \")[-1].lstrip(\"(\").rstrip(\")\")\n",
"settings, params = download_and_load_gpt2(model_size=model_size, models_dir=\"gpt2\")\n", "settings, params = download_and_load_gpt2(\n",
" model_size=model_size, \n",
" models_dir=\"gpt2\"\n",
")\n",
"\n", "\n",
"model = GPTModel(BASE_CONFIG)\n", "model = GPTModel(BASE_CONFIG)\n",
"load_weights_into_gpt(model, params)\n", "load_weights_into_gpt(model, params)\n",
@ -1645,7 +1664,11 @@
} }
], ],
"source": [ "source": [
"response_text = generated_text[len(input_text):].strip()\n", "response_text = (\n",
" generated_text[len(input_text):]\n",
" .replace(\"### Response:\", \"\")\n",
" .strip()\n",
")\n",
"print(response_text)" "print(response_text)"
] ]
}, },
@ -2026,7 +2049,11 @@
" eos_id=50256\n", " eos_id=50256\n",
" )\n", " )\n",
" generated_text = token_ids_to_text(token_ids, tokenizer)\n", " generated_text = token_ids_to_text(token_ids, tokenizer)\n",
" response_text = generated_text[len(input_text):].replace(\"### Response:\", \"\").strip()\n", " response_text = (\n",
" generated_text[len(input_text):]\n",
" .replace(\"### Response:\", \"\")\n",
" .strip()\n",
")\n",
"\n", "\n",
" print(input_text)\n", " print(input_text)\n",
" print(f\"\\nCorrect response:\\n>> {entry['output']}\")\n", " print(f\"\\nCorrect response:\\n>> {entry['output']}\")\n",
@ -2416,7 +2443,11 @@
"source": [ "source": [
"import urllib.request\n", "import urllib.request\n",
"\n", "\n",
"def query_model(prompt, model=\"llama3\", url=\"http://localhost:11434/api/chat\"):\n", "def query_model(\n",
" prompt, \n",
" model=\"llama3\", \n",
" url=\"http://localhost:11434/api/chat\"\n",
"):\n",
" # Create the data payload as a dictionary\n", " # Create the data payload as a dictionary\n",
" data = {\n", " data = {\n",
" \"model\": model,\n", " \"model\": model,\n",
@ -2435,7 +2466,11 @@
" payload = json.dumps(data).encode(\"utf-8\")\n", " payload = json.dumps(data).encode(\"utf-8\")\n",
"\n", "\n",
" # Create a request object, setting the method to POST and adding necessary headers\n", " # Create a request object, setting the method to POST and adding necessary headers\n",
" request = urllib.request.Request(url, data=payload, method=\"POST\")\n", " request = urllib.request.Request(\n",
" url, \n",
" data=payload, \n",
" method=\"POST\"\n",
" )\n",
" request.add_header(\"Content-Type\", \"application/json\")\n", " request.add_header(\"Content-Type\", \"application/json\")\n",
"\n", "\n",
" # Send the request and capture the response\n", " # Send the request and capture the response\n",
@ -2730,7 +2765,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.11" "version": "3.11.4"
} }
}, },
"nbformat": 4, "nbformat": 4,