{
"cells": [
{
"cell_type": "markdown",
"id": "e2e65c03-36d4-413f-9b23-5cdd816729ab",
"metadata": {
"id": "e2e65c03-36d4-413f-9b23-5cdd816729ab"
},
"source": [
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "6f678e62-7bcb-4405-86ae-dce94f494303",
"metadata": {
"id": "6f678e62-7bcb-4405-86ae-dce94f494303"
},
"source": [
"# Comparing Efficient Multi-Head Attention Implementations"
]
},
{
"cell_type": "markdown",
"id": "b742938a-4bfc-4527-a1f1-d5963508967d",
"metadata": {
"id": "b742938a-4bfc-4527-a1f1-d5963508967d"
},
"source": [
"This code notebook compares different ways to implement causal multi-head attention used in decoder-style LLMs like GPT, Llama, etc."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
"outputId": "1d132538-9d44-4393-c2de-6947fee4c793"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.4.0\n"
]
}
],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(123)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"PyTorch version: {torch.__version__}\")\n",
"\n",
"batch_size = 8\n",
"context_len = 1024\n",
"embed_dim = 768\n",
"embeddings = torch.randn((batch_size, context_len, embed_dim), device=device)"
]
},
{
"cell_type": "markdown",
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6",
"metadata": {
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6"
},
"source": [
"<br>\n",
" \n",
"\n",
"## 1) CausalAttention MHA wrapper class from chapter 3"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
"outputId": "b63ea806-ef0d-4673-c36a-c22f0381e7ad"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"from ch03 import MultiHeadAttentionWrapper as Ch03_MHA_Wrapper\n",
"\n",
"mha_ch03_wrapper = Ch03_MHA_Wrapper(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim//12,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_ch03_wrapper(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "21930804-b327-40b1-8e63-94dcad39ce7b",
"metadata": {
"id": "21930804-b327-40b1-8e63-94dcad39ce7b"
},
"source": [
"<br>\n",
" \n",
"\n",
"## 2) The multi-head attention class from chapter 3"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
"outputId": "79618eba-bc6f-4247-8d30-14381323a129"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"from ch03 import MultiHeadAttention as Ch03_MHA\n",
"\n",
"mha_ch03 = Ch03_MHA(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_ch03(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "73cd11da-ea3b-4081-b483-c4965dfefbc4",
"metadata": {
"id": "73cd11da-ea3b-4081-b483-c4965dfefbc4"
},
"source": [
"<br>\n",
" \n",
"\n",
"## 3) An alternative multi-head attention with combined weights"
]
},
{
"cell_type": "markdown",
"id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd",
"metadata": {
"id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd"
},
"source": [
"- The code for the `MultiHeadAttentionAlt` class below is based on code that was kindly shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51)\n",
"- The main difference between the `MultiHeadAttentionAlt` class and the `MultiHeadAttention` class used in chapter 3 is that `MultiHeadAttentionAlt` uses a single weight matrix, `self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)` instead of separate weight matrices:\n",
"\n",
" - `self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
" - `self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
" - `self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
"\n",
"- Here, `self.qkv` combines all three weight matrices `self.W_query`, `self.W_key`, and `self.W_value` to carry out the query, key, and value computation in a single step\n",
"- Using `q, k, v = qkv.unbind(0)`, we obtain the individual query, key, and value tensors, which are then used similarly to the query, key, and value tensors in the `MultiHeadAttention` class in chapter 3"
]
},
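{
"cell_type": "markdown",
"id": "3f1a2b4c-1d2e-4f3a-9b4c-5d6e7f8a9b01",
"metadata": {},
"source": [
"- As a minimal standalone sketch with toy sizes (an illustration added here, not part of the class below), the combined projection is reshaped, permuted, and unbound into query, key, and value tensors as follows"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f1a2b4c-2e3f-4a4b-8c5d-6e7f8a9b0c12",
"metadata": {},
"outputs": [],
"source": [
"# Toy sizes: b=2, num_tokens=4, embed_dim=6, num_heads=2, head_dim=3\n",
"demo_x = torch.randn(2, 4, 6)\n",
"demo_qkv = torch.nn.Linear(6, 3 * 6, bias=False)(demo_x)  # (2, 4, 18)\n",
"demo_qkv = demo_qkv.view(2, 4, 3, 2, 3)     # (b, num_tokens, 3, num_heads, head_dim)\n",
"demo_qkv = demo_qkv.permute(2, 0, 3, 1, 4)  # (3, b, num_heads, num_tokens, head_dim)\n",
"q, k, v = demo_qkv.unbind(0)\n",
"print(q.shape)  # torch.Size([2, 2, 4, 3])"
]
},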
{
"cell_type": "code",
"execution_count": 4,
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
"outputId": "011ba527-4dcb-4909-ffa2-383b92b81c9f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"import torch.nn as nn\n",
"\n",
"\n",
"class MultiHeadAttentionCombinedQKV(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.context_length = context_length\n",
" self.head_dim = d_out // num_heads\n",
"\n",
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" self.register_buffer(\n",
" \"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, embed_dim = x.shape\n",
"\n",
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
" qkv = self.qkv(x)\n",
"\n",
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
" qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
"\n",
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
"\n",
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n",
" queries, keys, values = qkv.unbind(0)\n",
"\n",
" # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n",
" attn_scores = queries @ keys.transpose(-2, -1)\n",
" attn_scores = attn_scores.masked_fill(\n",
" self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n",
" )\n",
"\n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**-0.5, dim=-1)\n",
" attn_weights = self.dropout(attn_weights)\n",
"\n",
" # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)\n",
" context_vec = attn_weights @ values\n",
"\n",
" # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)\n",
" context_vec = context_vec.transpose(1, 2)\n",
"\n",
" # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, embed_dim)\n",
" context_vec = context_vec.contiguous().view(batch_size, num_tokens, embed_dim)\n",
"\n",
" context_vec = self.proj(context_vec)\n",
"\n",
" return context_vec\n",
"\n",
"\n",
"mha_combined_qkv = MultiHeadAttentionCombinedQKV(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_combined_qkv(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "48a042d3-ee78-4c29-bf63-d92fe6706632",
"metadata": {
"id": "48a042d3-ee78-4c29-bf63-d92fe6706632"
},
"source": [
"<br>\n",
" \n",
"\n",
"## 4) Multihead attention with PyTorch's scaled dot product attention and FlashAttention"
]
},
{
"cell_type": "markdown",
"id": "f78e346f-3b85-44e6-9feb-f01131381148",
"metadata": {
"id": "f78e346f-3b85-44e6-9feb-f01131381148"
},
"source": [
"- The implementation below uses PyTorch's [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) function, which implements a memory-optimized version of self-attention called [FlashAttention](https://arxiv.org/abs/2205.14135)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5",
"metadata": {
"id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5"
},
"outputs": [],
"source": [
"class MHAPyTorchScaledDotProduct(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.context_length = context_length\n",
" self.head_dim = d_out // num_heads\n",
" self.d_out = d_out\n",
"\n",
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.dropout = dropout\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, embed_dim = x.shape\n",
"\n",
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
" qkv = self.qkv(x)\n",
"\n",
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
" qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
"\n",
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
"\n",
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
" queries, keys, values = qkv\n",
"\n",
" use_dropout = 0. if not self.training else self.dropout\n",
"\n",
" context_vec = nn.functional.scaled_dot_product_attention(\n",
" queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)\n",
"\n",
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
" context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
"\n",
" context_vec = self.proj(context_vec)\n",
"\n",
" return context_vec"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
"outputId": "fcd0508d-f474-4f81-89e5-7e1dbd6305f4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"mha_pytorch_scaled = MHAPyTorchScaledDotProduct(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_pytorch_scaled(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "51492724-6018-49f6-8bf6-ae9e585229c3",
"metadata": {
"id": "51492724-6018-49f6-8bf6-ae9e585229c3"
},
"source": [
"<br>\n",
" \n",
"\n",
"## 5) PyTorch's scaled dot product attention without FlashAttention\n",
"\n",
"- This is similar to above, except that we disable FlashAttention by passing an explicit causal mask"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bad53538-e905-4065-ba0c-caacdfec5a0b",
"metadata": {
"id": "bad53538-e905-4065-ba0c-caacdfec5a0b"
},
"outputs": [],
"source": [
"class MHAPyTorchSDPAWithoutFlash(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.context_length = context_length\n",
" self.head_dim = d_out // num_heads\n",
" self.d_out = d_out\n",
"\n",
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.dropout = dropout\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1).bool())\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, embed_dim = x.shape\n",
"\n",
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
" qkv = self.qkv(x)\n",
"\n",
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
" qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
"\n",
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
"\n",
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
" queries, keys, values = qkv\n",
"\n",
" use_dropout = 0. if not self.training else self.dropout\n",
"\n",
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
" if self.context_length >= num_tokens:\n",
" attn_mask = self.mask[:num_tokens, :num_tokens]\n",
" else:\n",
" attn_mask = self.mask[:self.context_length, :self.context_length]\n",
"\n",
" context_vec = nn.functional.scaled_dot_product_attention(\n",
" queries, keys, values, attn_mask=attn_mask, dropout_p=use_dropout, is_causal=False)\n",
"\n",
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
" context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
"\n",
" context_vec = self.proj(context_vec)\n",
"\n",
" return context_vec"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f3da7850-e772-47d3-bd51-22d077b01412",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "f3da7850-e772-47d3-bd51-22d077b01412",
"outputId": "841ee076-8f71-4223-85e0-9e74fc7ac2f4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"mha_pytorch_sdpa_no_flash = MHAPyTorchSDPAWithoutFlash(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_pytorch_sdpa_no_flash(embeddings)\n",
"print(out.shape)"
]
},
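{
"cell_type": "markdown",
"id": "5a6b7c8d-1234-4abc-9def-0a1b2c3d4e5f",
"metadata": {},
"source": [
"- As a quick sanity check on toy tensors (an illustrative addition, not from the chapter): `scaled_dot_product_attention` with `is_causal=True` matches a manually masked softmax attention"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a6b7c8d-5678-4abc-9def-0a1b2c3d4e60",
"metadata": {},
"outputs": [],
"source": [
"# Toy tensors: (b=1, num_heads=2, num_tokens=4, head_dim=8)\n",
"tq = torch.randn(1, 2, 4, 8)\n",
"tk = torch.randn(1, 2, 4, 8)\n",
"tv = torch.randn(1, 2, 4, 8)\n",
"\n",
"out_sdpa = nn.functional.scaled_dot_product_attention(tq, tk, tv, is_causal=True)\n",
"\n",
"# Manual version: default SDPA scale is 1/sqrt(head_dim)\n",
"scores = tq @ tk.transpose(-2, -1) / 8**0.5\n",
"causal = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)\n",
"out_manual = torch.softmax(scores.masked_fill(causal, -torch.inf), dim=-1) @ tv\n",
"\n",
"print(torch.allclose(out_sdpa, out_manual, atol=1e-6))  # True"
]
},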
{
"cell_type": "markdown",
"id": "351c318f-4835-4d74-8d58-a070222447c4",
"metadata": {
"id": "351c318f-4835-4d74-8d58-a070222447c4"
},
"source": [
"<br>\n",
" \n",
"\n",
"## 6) Using PyTorch's torch.nn.MultiheadAttention"
]
},
{
"cell_type": "markdown",
"id": "74a6d060-6324-48fa-a35c-cb09f2a48965",
"metadata": {
"id": "74a6d060-6324-48fa-a35c-cb09f2a48965"
},
"source": [
"- Below, we use PyTorch's [torch.nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) implementation"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3799c7ef-3155-42c6-a829-f95656453ae0",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3799c7ef-3155-42c6-a829-f95656453ae0",
"outputId": "2f001121-6357-4ee2-9ba5-8ec1e687ea23"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"import torch.nn as nn\n",
"\n",
"\n",
"class MHAPyTorchClass(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False, need_weights=True):\n",
" super().__init__()\n",
"\n",
" self.context_length = context_length\n",
" self.multihead_attn = nn.MultiheadAttention(\n",
" embed_dim=d_out,\n",
" num_heads=num_heads,\n",
" dropout=dropout,\n",
" bias=qkv_bias,\n",
" add_bias_kv=qkv_bias,\n",
" batch_first=True,\n",
" )\n",
"\n",
" self.need_weights = need_weights\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1).bool())\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, _ = x.shape\n",
"\n",
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
" if self.context_length >= num_tokens:\n",
" attn_mask = self.mask[:num_tokens, :num_tokens]\n",
" else:\n",
" attn_mask = self.mask[:self.context_length, :self.context_length]\n",
"\n",
" # attn_mask broadcasting will handle batch_size dimension implicitly\n",
" attn_output, _ = self.multihead_attn(\n",
" x, x, x, attn_mask=attn_mask, need_weights=self.need_weights\n",
" )\n",
"\n",
" output = self.proj(attn_output)\n",
"\n",
" return output\n",
"\n",
"\n",
"mha_pytorch_class_default = MHAPyTorchClass(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_pytorch_class_default(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "a3953bff-1056-4de2-bfd1-dfccf659eee4",
"metadata": {
"id": "a3953bff-1056-4de2-bfd1-dfccf659eee4"
},
"source": [
"<br>\n",
" \n",
"\n",
"## 7) Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`"
]
},
{
"cell_type": "markdown",
"id": "d2164859-31a0-4537-b4fb-27d57675ba77",
"metadata": {
"id": "d2164859-31a0-4537-b4fb-27d57675ba77"
},
"source": [
"- Set `need_weights` (default `True`) to need_weights=False so that MultiheadAttention uses `scaled_dot_product_attention` [according to the documentation](https://github.com/pytorch/pytorch/blob/71d020262793542974cf13b30f2a9099773f015c/torch/nn/modules/activation.py#L1096)\n",
"\n",
"> need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.\n",
" Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``\n",
" and achieve the best performance for MHA.\n",
" Default: ``True``."
]
},
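{
"cell_type": "markdown",
"id": "7e8f9a0b-1122-4ccd-8eef-1a2b3c4d5e6f",
"metadata": {},
"source": [
"- A tiny standalone illustration (not from the chapter): with `need_weights=False`, `nn.MultiheadAttention` returns `None` in place of the attention weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7e8f9a0b-3344-4ccd-8eef-1a2b3c4d5e70",
"metadata": {},
"outputs": [],
"source": [
"demo_mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)\n",
"demo_x = torch.randn(1, 4, 8)\n",
"\n",
"attn_out, attn_weights = demo_mha(demo_x, demo_x, demo_x, need_weights=False)\n",
"print(attn_weights)    # None\n",
"print(attn_out.shape)  # torch.Size([1, 4, 8])"
]
},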
{
"cell_type": "code",
"execution_count": 10,
"id": "4a4c2afe-5e1f-4bd7-a118-67031176f147",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4a4c2afe-5e1f-4bd7-a118-67031176f147",
"outputId": "311c1e2b-f437-4c4d-cbdc-1ecb0db3b78c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"mha_pytorch_class_noweights = MHAPyTorchClass(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False,\n",
" need_weights=False # NEW!\n",
").to(device)\n",
"\n",
"out = mha_pytorch_class_noweights(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "21f4ff35-651c-4e47-bfa1-016f3de01ecc",
"metadata": {
"id": "21f4ff35-651c-4e47-bfa1-016f3de01ecc"
},
"source": [
"<br>\n",
" \n",
"\n",
"## 8) Using PyTorch's FlexAttention\n",
"\n",
"- See [FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention](https://pytorch.org/blog/flexattention/) to learn more about FlexAttention\n",
"- This is currently only supported in PyTorch 2.5 (nightly), which you can install on a CPU machine via\n",
"\n",
"```python\n",
"pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu -U\n",
"```\n",
"\n",
"- To install PyTorch nighly on a GPU machine, use the following (for more information, also see the installation menu on [pytorch.org](https://pytorch.org/))\n",
"\n",
"```python\n",
"pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 -U\n",
"```\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "834318c8-4748-4902-99f0-70ee02bef63e",
"metadata": {
"id": "834318c8-4748-4902-99f0-70ee02bef63e"
},
"outputs": [],
"source": [
"from packaging.version import parse as parse_version\n",
"\n",
"def normalize_version(version):\n",
" parsed_version = parse_version(version)\n",
" return parse_version(f\"{parsed_version.major}.{parsed_version.minor}.{parsed_version.micro}\")\n",
"\n",
"current_version = normalize_version(torch.__version__)\n",
"MIN_TORCH_VERSION = \"2.5.0\"\n",
"required_version = parse_version(MIN_TORCH_VERSION)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "WYyFRCXndVH9",
"metadata": {
"id": "WYyFRCXndVH9"
},
"outputs": [],
"source": [
"if current_version >= required_version:\n",
" from torch.nn.attention.flex_attention import flex_attention\n",
" from torch.nn.attention.flex_attention import create_block_mask\n",
"\n",
"\n",
"def causal(b, h, q_idx, kv_idx):\n",
" return q_idx >= kv_idx\n",
"\n",
"\n",
"class MHAPyTorchFlexAttention(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.context_length = context_length\n",
" self.head_dim = d_out // num_heads\n",
" self.d_out = d_out\n",
"\n",
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.dropout = dropout\n",
" # self.register_buffer(\"block_mask\", create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length))\n",
" # `create_block_mask` function doesn not support buffers, yet\n",
" self.block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length)\n",
"\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, embed_dim = x.shape\n",
"\n",
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
" qkv = self.qkv(x)\n",
"\n",
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
" qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
"\n",
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
"\n",
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
" queries, keys, values = qkv\n",
"\n",
" use_dropout = 0. if not self.training else self.dropout\n",
"\n",
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
" if self.context_length >= num_tokens:\n",
" attn_mask = self.block_mask[:num_tokens, :num_tokens]\n",
" else:\n",
" attn_mask = self.block_mask[:self.context_length, :self.context_length]\n",
"\n",
" context_vec = flex_attention(queries, keys, values, block_mask=attn_mask)\n",
"\n",
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
" context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
"\n",
" context_vec = self.proj(context_vec)\n",
"\n",
" return context_vec"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "9cdaaf8a-f956-44bc-932f-4d33448e8aaf",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9cdaaf8a-f956-44bc-932f-4d33448e8aaf",
"outputId": "0888271f-d6ae-4905-fabb-805b70f9e712"
},
"outputs": [],
"source": [
"if current_version >= required_version:\n",
"\n",
" mha_pytorch_flex = MHAPyTorchFlexAttention(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
" ).to(device)\n",
"\n",
" out = mha_pytorch_flex(embeddings)\n",
" print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "8877de71-f84f-4f6d-bc87-7552013b6301",
"metadata": {
"id": "8877de71-f84f-4f6d-bc87-7552013b6301"
},
"source": [
"<br>\n",
" \n",
"\n",
"## Quick speed comparison (M3 Macbook Air CPU)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "219cf93a-078f-434d-888c-2458d0731285",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "219cf93a-078f-434d-888c-2458d0731285",
"outputId": "2bb41cd4-a152-4754-f361-94f9e17cf498"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.4.0\n",
"Running on cpu\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"PyTorch version: {torch.__version__}\")\n",
"print(f\"Running on {device}\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
"outputId": "67389a57-945b-42c9-e1a1-1e2a8b8cb710"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"196 ms ± 3.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"## 1) CausalAttention MHA wrapper class from chapter 3\n",
"%timeit mha_ch03_wrapper(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
"outputId": "930a14ee-36df-4a41-c162-374f9e1ea600"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"204 ms ± 14.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"## 2) The multi-head attention class from chapter 3\n",
"%timeit mha_ch03(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
"outputId": "801cf4e1-4f2a-44d8-e5d7-89fe0c62cee9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"214 ms ± 8.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"## 3) An alternative multi-head attention with combined weights\n",
"%timeit mha_combined_qkv(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
"outputId": "b9b43252-8942-46f4-84de-4b2889e3fb7e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"81 ms ± 15.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"## 4) Multihead attention with PyTorch's scaled dot product attention\n",
"%timeit mha_pytorch_scaled(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "c44305ce-9f61-451a-b9ef-30caba222357",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c44305ce-9f61-451a-b9ef-30caba222357",
"outputId": "a044e5f2-212e-45d5-ea40-4f4611077794"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"83.7 ms ± 2.32 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"## 5) PyTorch's scaled dot product attention without FlashAttention\n",
"%timeit mha_pytorch_sdpa_no_flash(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d",
"outputId": "c3f002da-b9b6-42d1-a24b-45789369c9a9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"251 ms ± 23.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"## 6) Using PyTorch's torch.nn.MultiheadAttention\n",
"%timeit mha_pytorch_class_default(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756",
"outputId": "7d85c861-6e77-4d8b-8cb3-b05d8ddd243d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"131 ms ± 20.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"## 7) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n",
"%timeit mha_pytorch_class_noweights(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d7ee70d-7bdd-48ad-ad7f-af3bf2439609",
"metadata": {},
"outputs": [],
"source": [
"## 8) Using PyTorch's FlexAttention\n",
"\n",
"# Requires PyTorch 2.5.0 or newer\n",
"%timeit mha_pytorch_flex(embeddings)"
]
},
{
"cell_type": "markdown",
"id": "a78ff594-6cc2-496d-a302-789fa104c3c9",
"metadata": {
"id": "a78ff594-6cc2-496d-a302-789fa104c3c9"
},
"source": [
"<br>\n",
" \n",
"\n",
"## Quick speed comparison (Nvidia A100 GPU)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "RStnI1pEi6Eo",
"metadata": {
"id": "RStnI1pEi6Eo"
},
"outputs": [],
"source": [
"# Enable tensor cores\n",
"torch.set_float32_matmul_precision(\"high\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8431d75-e1c9-4d9a-b7da-9a1ff391f2bf",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "e8431d75-e1c9-4d9a-b7da-9a1ff391f2bf",
"outputId": "5176759d-9599-4c8b-90bd-58cf8890ccde"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.5.0.dev20240810+cu121\n",
"Running on cuda\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"PyTorch version: {torch.__version__}\")\n",
"print(f\"Running on {device}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "707a2a14-a089-48a8-88aa-d328e1e0a9d0",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "707a2a14-a089-48a8-88aa-d328e1e0a9d0",
"outputId": "d0fe060f-058a-43ec-8d23-de23a94ed4cb"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.35 ms ± 30.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 1) CausalAttention MHA wrapper class from chapter 3\n",
"%timeit mha_ch03_wrapper(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8686dd69-3655-40e4-a57b-a2c55532a010",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8686dd69-3655-40e4-a57b-a2c55532a010",
"outputId": "b404c9e6-e0e0-4cb4-921f-fd612beb8668"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.09 ms ± 136 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"## 2) The multi-head attention class from chapter 3\n",
"%timeit mha_ch03(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2209d7df-e54b-4910-ae2b-c78cf684d9bf",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2209d7df-e54b-4910-ae2b-c78cf684d9bf",
"outputId": "ef47787d-65bb-4407-d718-77dadb6f53d2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.81 ms ± 189 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 3) An alternative multi-head attention with combined weights\n",
"%timeit mha_combined_qkv(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1075abe2-4839-4fd6-af3e-c09bb3651e26",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1075abe2-4839-4fd6-af3e-c09bb3651e26",
"outputId": "58c8f320-79e2-400f-978c-1983fb810cf3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.21 ms ± 875 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"## 4) Multihead attention with PyTorch's scaled dot product attention\n",
"%timeit mha_pytorch_scaled(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "218adbaf-f17f-47d9-81d5-41c758218df7",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "218adbaf-f17f-47d9-81d5-41c758218df7",
"outputId": "4e66760e-516a-43a5-a148-0afef96d23ca"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.96 ms ± 117 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 5) PyTorch's scaled dot product attention without FlashAttention\n",
"%timeit mha_pytorch_sdpa_no_flash(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "868e3670-8edc-47bc-9e06-eb505e44dc9d",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "868e3670-8edc-47bc-9e06-eb505e44dc9d",
"outputId": "c17ddc03-0aab-4ef5-d3e3-a3695ec30b62"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.05 ms ± 225 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 6) Using PyTorch's torch.nn.MultiheadAttention\n",
"%timeit mha_pytorch_class_default(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "944870e6-de54-4e3b-a455-b8f21f6f92c8",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "944870e6-de54-4e3b-a455-b8f21f6f92c8",
"outputId": "e2781053-e7f4-4d87-99a1-08ca20f92951"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.29 ms ± 5.66 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 7) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n",
"%timeit mha_pytorch_class_noweights(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "evKtpb5QN_2A",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "evKtpb5QN_2A",
"outputId": "a8af4a1c-e1c2-4a41-f454-cd395818a327"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12.2 ms ± 587 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"## 8) Using PyTorch's FlexAttention\n",
"\n",
"# Requires PyTorch 2.5.0 or newer\n",
"%timeit mha_pytorch_flex(embeddings)"
]
},
{
"cell_type": "markdown",
"id": "dabc6575-0316-4640-a729-e616d5c17b73",
"metadata": {
"id": "dabc6575-0316-4640-a729-e616d5c17b73"
},
"source": [
"<br>\n",
" \n",
"\n",
"\n",
"## Speed comparison (Nvidia A100 GPU) with warmup (forward pass only)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bbb2f729-d3d8-46d0-b249-9249197ea574",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bbb2f729-d3d8-46d0-b249-9249197ea574",
"outputId": "09bf4056-d219-4f90-f8f7-f31b62ffa671"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.5.0.dev20240810+cu121\n",
"Running on cuda\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"PyTorch version: {torch.__version__}\")\n",
"print(f\"Running on {device}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000",
"metadata": {
"id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000"
},
"outputs": [],
"source": [
"# CUDA benchmark code shared by Andrei Aksionov\n",
"# and based on code from\n",
"# https://github.com/cuda-mode/lectures/blob/main/lecture1/pytorch_square.py\n",
"\n",
"def time_pytorch_function(func, *input, num_repeats = 1_000):\n",
" # CUDA is asynchronous, so we can't use Python's time module\n",
" start = torch.cuda.Event(enable_timing=True)\n",
" end = torch.cuda.Event(enable_timing=True)\n",
"\n",
" # Warmup\n",
" for _ in range(5):\n",
" func(*input)\n",
" torch.cuda.synchronize()\n",
"\n",
" start.record()\n",
" for _ in range(num_repeats):\n",
" func(*input)\n",
" torch.cuda.synchronize()\n",
" end.record()\n",
" torch.cuda.synchronize()\n",
" return start.elapsed_time(end) / num_repeats"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "CDJAPZaszaqx",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 488
},
"id": "CDJAPZaszaqx",
"outputId": "86fef2d2-856b-4ae1-beb9-cc04b4ed6aad"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAnIAAAHWCAYAAADzS2TwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd1gU1/f48ffSpaggomLvYtDYe0GNvYCiYMGu2FBs2BBRrAj23nuLvfeusStGsXdFBcGKigTY3x/8mC+rJp8YyzJ4Xs/DE5mdXc7NzM6evXPvuRobGxstQgghhBBCdQz0HYAQQgghhPhvJJETQgghhFApSeSEEEIIIVRKEjkhhBBCCJWSRE4IIYQQQqUkkRNCCCGEUClJ5IQQQgghVEoSOSGEEEIIlTLSdwBqlCVLFqKjo/UdhhBCCCFSKUtLS548efI/95NE7gtlyZKFy5cv6zsMIYQQQqRyjo6O/zOZS3WJXPny5fHy8qJYsWJkzpyZ1q1bs2PHDuXx6dOn06JFC53n7N+/Hzc3t3/1+kk9cY6OjtIrJ4QQQohvztLSksuXL/+rPCPVJXLm5uaEhoaycuVKli5d+tl99u3bR8+ePZXfP3z48MV/Jzo6mjdv3vznOIUQQgghvlaqS+T279/P/v37/3Gf2NhYIiIiflBEQgghhBDfR6pL5P6NihUrcu3aNV69esXRo0cZPXo0L168+Oy+JiYmmJqaKr9bWlr+qDCFEEIIIf7RT5fI7d+/n23btnH//n1y587N0KFD+f3336lduzYJCQmf7N+7d28GDhyoh0iFEEIIIf7ZT5fIbdy4Ufn31atXCQ0N5fz581SqVIkjR458sv/kyZOZNWuW8nvSAEQhhBBCCH376QsC379/n8jISHLnzv3Zx2NjY3nz5o3yIzNVhRBCCJFS/PSJnL29PTY2NoSHh+s7FCGEEEKIL5LqEjkLCwscHR1xdHQEIEeOHDg6OpI1a1YsLCwYPnw4pUqVInv27FSpUoVly5Zx584dDhw4oOfIhRBCiJ9X+fLlWbFiBaGhoURFRVGvXj3lMSMjI/z9/Tl69CgPHjwgNDSUmTNnkjlz5n98TQMDAwYPHsz58+d59OgRZ8+epV+/fjr7TJ8+naioKJ2f33///bu08XtIdWPkihUrxpYtW5TfR48eDcCqVavo378/v/zyC82bNyddunQ8ffqUgwcPMnbsWGJjY/UVshBCCPHT+6c6sGnSpKFo0aIEBwcTGhpK+vTpGTNmDCtWrKBGjRp/+5re3t60b9+eHj16cO3aNYoVK8b06dN58+YNc+fOVfb7FvVl9SXVJXLHjx8nQ4YMf/t4s2bNfmA0QgghhPg3/qkO7Js3b3B1ddXZNnDgQPbt20fWrFkJCwv77PNKly7Nzp072bt3LwAPHz7E1dWVEiVK6Oyn5vqyqe7WqhBCCCFSv7Rp05KQkMDr16//dp8zZ85QpUoV8ubNC8Avv/xC2bJl2bdvn85+SfVlT506RXBwMNbW1t819m8p1fXICSGEECJ1MzU1ZdiwYaxfv/4fl8ucPHkyVlZWnDx5kvj4eAwNDRk9ejTr1q1T9vnS+rIpjSRyQgghhFANIyMjFixYgEajwcfH5x/3dXFxoWnTpnh6enLt2jWKFCnC6NGjefr0KatXrwa+vL5sSiOJnBBCCCFUwcjIiIULF5I9e3ZcXFz+sTcOYMSIEUyZMkVJ1q5evUr27Nnp3bu3ksh9LHl9WUnkhBBCCCG+gaQkLk+ePDg7O//tGunJpUmT5pPbo/Hx8Wg0mr99jtrqy0oiJ4QQQgi9s7Cw0FllKakO7IsXLwgPD2fx4sUULVqUFi1aYGhoiJ2dHQAvXrzgr7/+AhJvk27fvp358+cDsHv3bvr27cujR4+4du0aRYsWpVu3bqxcuVL5mz4+Pmzbto3w8HBy586Nv7+/qurLSiInhBBCCL37pzqwgYGB1K1bF+CT252NGjXi+PHjAOTKlQ
sbGxvlsUGDBjF48GCCgoKwtbXl6dOnLFmyhKCgICCxd07t9WU1NjY2Wn0HoSZWVlbcu3ePXLly/c9780IIIYQQX+pLcg2pIyeEEEIIoVKSyAkhhBBCqJQkckIIIYQQKpUiJjvkyJGD8uXLky1bNszNzYmMjOTSpUucOXNGVQvXCiGEEEL8SHpN5Jo2bUqXLl0oVqwYERERPH36lJiYGKytrcmVKxcfPnxg3bp1TJkyhUePHukzVCGEEEKIFEdvidzBgwf566+/WLVqFW3btuXx48c6j5uYmFC6dGkaN27M/v378fHx0ZmWLIQQQgjxs9Nb+ZFq1apx8ODBf7WvtbU1OXLk4OLFi985qv9Nyo8IIYQQ4nv6klxDrz1y/9aLFy/+1VIcQgghhBA/kxQxa7Vo0aI4ODgov9etW5dly5YxdOhQjI2N9RiZEEIIIUTKlSISuYkTJ5IvXz4AcubMybx583j37h2NGjVi+PDh+g1OCCGEECKFShHlR/LmzculS5cAcHZ25sSJE3Tp0oUyZcowf/58fH199RyhEEIIIb6WVeel+g7hm3kzr42+QwBSSI+cRqPBwCAxlKpVq7J3714AwsLCdBa/FUIIIYQQ/ydFJHIhISH069cPNzc3KlSooCRyOXPm5NmzZ3qOTgghhBAiZUoRidyQIUMoWrQogYGBTJw4kbt37wLQqFEjTp8+refohBBCCCFSphQxRu7KlStUrlz5k+3+/v7Ex8frISIhhBBCiJQvRSRyyVlYWCjj5ZJI4V0hhBBCiE+liEQuR44cBAYGUrFiRczMzJTtGo0GrVaLnZ2dHqMTQgghhEiZUkQiN3v2bDQaDb169eLZs2dotXpZNUwIIYQQQlVSRCL3yy+/UKNGDW7duqXvUIQQQgghVCNFzFq9cOECWbNm1XcYQgghhBCqkiJ65Hr37s2ECRPIkiULV69e5a+//tJ5/MqVK3qKTAghhBAi5UoRiZytrS25cuVi2rRpyjatViuTHYQQQggh/kGKSOSmTp3KpUuX8PT0JCIiQiY7CCGEEEL8CykikcuWLRutWrVSVnQQQgghhBD/W4qY7HD06FEcHR31HYYQQgghhKqkiB653bt3M2rUKBwcHD472WHXrl16ikwIIYQQIuVKEYnchAkTAPDx8fnkMZnsIIQQQgjxeSkikcuYMaO+QxBCCCGEUJ0UMUZOCCGEEEJ8Ob0lco0bN/7X+9rb21OmTJnvGI0QQgghhProLZFr3749J06coGfPnhQoUOCTx62srPjtt9+YM2cOBw8exMbGRg9RCiGEEEKkXHobI9eoUSPq1KlD586d8fPz4927d0RERPDhwwfSp0+PnZ0dUVFRrF69mkqVKvHs2TN9hSqEEEIIkSLpdbLDrl272LVrFzY2NpQrV45s2bKRJk0aoqKiuHTpEn/++aes8iCEEEII8TdSxKzV58+fs2PHDn2HIYQQQgihKjJrVQghhBBCpSSRE0IIIYRQKUnkhBBCCCFUShI5IYQQQgiVSlGJnLGxMfny5cPQ0FDfoQghhBBCpHgpIpFLkyYNU6ZM4dGjRxw/fpxs2bIBMG7cOLy9vfUcnRBCCCFEypQiEjk/Pz8cHR1p1KgRMTExyvbDhw/j4uLyRa9Vvnx5VqxYQWhoKFFRUdSrV++TfQYNGkRoaCiPHj1iw4YN5MmT52ubIIQQQgjxw6WIRK5evXoMHDiQU6dO6Wy/du0auXPn/qLXMjc3JzQ0lAEDBnz28V69euHp6Un//v2pVasW7969Y+3atZiamv7n+IUQQggh9CFFFATOkCHDZ5fgMjc3/+KVHfbv38/+/fv/9vEuXbowYcIEdu7cCUC3bt24du0a9erVY+PGjV8WuBBCCCGEHqWIHrmQkBBq1aql/J6UvLVu3ZozZ858s7+TM2dOMmfOzOHDh5Vtb9684dy5c5QuXfqzzzExMcHKykr5sbS0/GbxCCGEEEJ8jRTRIzdq1Ch+//13ChYsiKGhIV26dK
FgwYKULl2aRo0afbO/Y2dnB/BJ79+zZ8+Uxz7Wu3dvBg4c+M1iEEIIIYT4VlJEj9ypU6eoWrUqhoaGXL16lWrVqhE
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"embeddings_cuda = embeddings.to(torch.device(\"cuda\"))\n",
"\n",
"\n",
"functions = {\n",
" \"1) MHA wrapper class\": mha_ch03_wrapper,\n",
" \"2) MHA Ch03\": mha_ch03,\n",
" \"3) MHA with combined QKV weights\": mha_combined_qkv,\n",
" \"4) MHA with PyTorch scaled_dot_product_attention\": mha_pytorch_scaled,\n",
" \"5) PyTorch's SDPA, no FlashAttention\": mha_pytorch_sdpa_no_flash,\n",
" \"6) PyTorch MHA class defaults\": mha_pytorch_class_default,\n",
" \"7) PyTorch MHA with need_weights=False\": mha_pytorch_class_noweights,\n",
"\n",
"}\n",
"\n",
"if current_version >= required_version:\n",
" functions[\"8) PyTorch's FlexAttention\"] = mha_pytorch_flex\n",
"\n",
"execution_times = [time_pytorch_function(fn, embeddings_cuda) for name, fn in functions.items()]\n",
"\n",
"\n",
"# Plotting\n",
"\n",
"# Customize further for dark mode aesthetics\n",
"plt.rcParams[\"figure.facecolor\"] = \"#121212\" # Dark figure background\n",
"plt.rcParams[\"axes.facecolor\"] = \"#121212\" # Dark axes background\n",
"plt.rcParams[\"axes.edgecolor\"] = \"white\" # White axes border\n",
"plt.rcParams[\"axes.labelcolor\"] = \"white\" # White labels\n",
"plt.rcParams[\"text.color\"] = \"white\" # White text\n",
"plt.rcParams[\"xtick.color\"] = \"white\" # White x ticks\n",
"plt.rcParams[\"ytick.color\"] = \"white\" # White y ticks\n",
"plt.rcParams[\"grid.color\"] = \"#444444\" # Lighter grid lines for contrast\n",
"plt.rcParams[\"lines.linewidth\"] = 2 # Thicker plot lines for visibility\n",
"plt.rcParams[\"lines.markersize\"] = 8 # Larger markers for visibility\n",
"\n",
"fig, ax = plt.subplots()\n",
"bars = plt.bar(functions.keys(), execution_times)\n",
"\n",
"plt.ylabel(\"Execution time (ms)\")\n",
"plt.xticks(rotation=45, ha=\"right\")\n",
"\n",
"# Calculate new ylim with a margin\n",
"max_execution_time = max(execution_times)\n",
"upper_ylim = max_execution_time + 0.2 * max_execution_time # Adding a 20% margin\n",
"\n",
"plt.ylim(0, upper_ylim) # Setting new ylim\n",
"\n",
"# Annotate bars with execution times\n",
"for bar in bars:\n",
" yval = bar.get_height()\n",
" plt.text(bar.get_x() + bar.get_width()/2, yval + (0.05 * upper_ylim), round(yval, 2), ha=\"center\", va=\"bottom\")\n",
"\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig(\"1_forward-only.pdf\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "VQaSerWCOnYB",
"metadata": {
"id": "VQaSerWCOnYB"
},
"source": [
"<br>\n",
" \n",
"\n",
"\n",
"## Speed comparison (Nvidia A100 GPU) with warmup (forward and backward pass)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ReCmeRhCOpm8",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 488
},
"id": "ReCmeRhCOpm8",
"outputId": "79daea84-6ca3-41d2-91c9-ecabb1a035bb"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAnIAAAHWCAYAAADzS2TwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzddVQU6xvA8e/SqYKIjdjixW4s1Ct2dzcWNtjYjd2dWNduVOzuRBEDDFQQTESk9veHh/mxxr3XK7osPp9zPLqzM7PP68zOPPvOGypra2s1QgghhBBC5+hpOwAhhBBCCPHfSCInhBBCCKGjJJETQgghhNBRksgJIYQQQugoSeSEEEIIIXSUJHJCCCGEEDpKEjkhhBBCCB0liZwQQgghhI4y0HYAuihjxoxERERoOwwhhBBCpFAWFhY8e/bsH9eTRO47ZcyYkZs3b2o7DCGEEEKkcI6Ojv+YzKWoRK5v377Url2b3Llz8+HDBy5cuMDo0aO5d++ess6OHTsoV66cxnYrVqzA3d39X31GQk2co6Oj1MoJIYQQIslZWFhw8+bNf5VnpKhEzsnJiWXLlnH58mUMDAwYPnw4mzdvxsnJicjISGW9VatWMWnSJOX1hw8fvvuzIiIiePfuXZLELYQQQgjxX6SoRK5p06Yar93c3AgICKBQoUKcOXNGWf7hwwdCQ0N/dXhCCCGEEEkqRfdaTZUqFQCvXr3SWN64cWMCAgI4efIknp6emJqafnMfRkZGWFpaKn8sLCx+asxCCCGEEP9WiqqRS0ylUjF+/HjOnj2Lv7+/snzLli08fvyY58+f88cffzBy5Ehy5cpFu3btvrqfvn37MmjQoF8VthBCCCHEv6aytrZWazuIn2Hq1KlUqVKFWrVq8fTp02+uV758ebZv306xYsUICgr64n0jIyOMjY2V1wkNEO3t7aWNnBBCCCGSnKWlJUFBQf8q10iRNXKTJ0/GxcWF2rVr/20SB3Dp0iUAsmfP/tVELjo6mujo6J8RphBCCCHED0lxidzkyZOpVasWdevW5dGjR/+4vqOjIwAhISE/OzQhhBBCiCSVojo7eHl50aRJE1xdXYmIiMDW1hZbW1tMTEwAsLe3Z8CAARQqVIisWbNSvXp15s+fz6lTp7h165aWoxdCCCF+X3379sXX15eHDx/i7+/PmjVryJUrl8Y6bdu2ZceOHQQFBREeHq50avw7HTp04Pjx4wQFBREUFISPjw9VqlTRWMfW1pYFCxZw69YtHj16xOHDh6lTp06Slu9nSVGJXMeOHUmdOjW7du3i9u3byp8GDRoAnx6TVqxYkc2bN3P27FnGjBnDrl27aNWqlZYjF0IIIX5vCWPBuri40KhRIwwMDNi8eTNmZmbKOqamphw+fJgZM2b86/0+ffqUMWPGULlyZapUqcKJEyfw9vYmb968yjrz588nV65ctG7dmvLly7Nnzx6WLVtGgQIFkrSMP0OK7ezws3xPA0QhhBBC/Ddp06YlICCA2rVra4wFC1C2bFl27txJ9uzZefv27Xfv+969e4wcOZK1a9cC8PDhQzw8PPjrr7+Ude7evcvo0aPx9vb+sYL8B9+Ta6SoGjkhhBBCpAzfGgv2R+jp6dGgQQPMzMy4ePGisvzChQvUr1+fNGnSoFKpaNCgAcbGxpw6dSrJPvtnSXGdHYQQQgih2741Fux/5eDggI+PDyYmJrx//562bdty584d5f2OHTuybNky7t+/T0xMDB8+fKBt27YEBgb+8Gf/bJLICSGEECJZ8fLywsHBgVq1aiXJ/u7du4ezszOpUqWibt26zJs3j7p16yrJ3NChQ0mdOjUNGjQgPDycmjVrsnz5cmrVqsXt27eTJIafRRI5IYQQQiQb3zMW7L8VExOj1K5du3aNIkWK4OrqyoABA7C3t6dLly44OTkpiZ2fnx9lypShU6dOuLu7J0kMP4skckIIIYRIFr53LNj/Sk9PT5m1KWG+dbVas+
9nXFwcenrJvytB8o9QCCGEECneP40FC5/Ge3N0dCR79uwA5M+fH0dHR9KkSaOss23bNjp37qy89vT0pEyZMmTNmhUHBwc8PT0pW7YsmzdvBj71Tr1//z7Tpk2jaNGi2Nvb06NHD5ydndm7d++vKfwPkBo5IYQQQmhdx44dAdi1a5fGcjc3N9avXw9A+/btGTRokPLenj17vljH3t4ea2trZR0bGxvmz59P+vTpefv2Lbdu3aJJkyYcPXoUgNjYWJo3b86IESNYu3Yt5ubmBAYG0rNnT3x9fX9aeZOKjCP3nWQcOSGEEEL8TDKOnBBCCCHEb0ASOSGEEEIIHSWJnBBCCCGEjpJETgghhBBCRyWLXqt2dnaUKVOGLFmyYGZmRlhYGDdu3ODChQt8/PhR2+EJIYQQQiRLWk3kGjduTNeuXSlcuDChoaE8f/6cqKgorKyssLe35+PHj2zevJlZs2bx5MkTbYYqhBBCCJHsaC2RO3LkCDExMaxfv5527dp9MQ2HkZERJUqUoEGDBhw6dAgPDw927typpWiFEEIIIZIfrY0jV6lSJY4cOfKv1rWyssLOzo5r16795Kj+mYwjJ4QQQoif6XtyDa3WyP1br1694tWrVz8xGiGEEEII3ZMseq0WLFgQBwcH5XWNGjVYs2YNw4cPx9DQUIuRCSGEEEIkX8mi1+r06dOZNWsWt2/fJlu2bCxZsoQ9e/ZQt25dTE1NGTZsmLZDFEIIIcQPsuyyWtshJJl3S9pqOwQgmdTI5cyZkxs3bgBQr149zpw5Q9euXXFzc6NOnTpajk4IIYQQInlKFomcSqVCT+9TKBUrVuTgwYMABAcHY21trc3QhBBCCCGSrWSRyF29epUBAwbQtGlTnJyclEQuW7ZsvHjxQsvRCSGEEEIkT8kikRs6dCgFCxZk8uTJTJ8+ncDAQADq1q3L+fPntRydEEIIIUTylCw6O9y6dYvy5ct/sXzkyJHExcVpISIhhBBCiOQvWdTIJWZubo6lpSWWlpYYGRlhamr6r7ft27cvvr6+PHz4EH9/f9asWUOuXLk01jE2NmbKlCncvXuXhw8fsnLlStKlS5fUxRBCCCGE+OmSRSJnZ2fH+vXrefToEYGBgdy/f5/79+/z4MED7t+//6/34+TkxLJly3BxcaFRo0YYGBiwefNmzMzMlHXGjx9PtWrV6NixI3Xr1iVDhgysWrXqZxRLCCGEEOKnShaPVhcuXIhKpaJ37968ePECtfq/zRrWtGlTjddubm4EBARQqFAhzpw5g6WlJa1atcLV1ZUTJ04A0KtXL86ePUvx4sW5ePHiD5dFCCGEEOJXSRaJ3B9//EGVKlW4d+9eku43VapUAMr0XoULF8bIyIhjx44p69y9e5fHjx9/M5EzMjLC2NhYeW1hYZGkMQohhBBC/FfJ4tHqlStXyJw5c5LuU6VSMX78eM6ePYu/vz8Atra2fPz4kbdv32qs++LFC9KnT//V/fTt25egoCDlz82bN5M0TiGEEEKI/ypZ1Mj17duXadOmkTFjRm7fvk1MTIzG+7du3frufXp5eeHg4ECtWrV+KLaZM2eyYMEC5bWFhYUkc0IIIYRIFpJFImdjY4O9vT1z5sxRlqnValQqFWq1Gltb2+/a3+TJk3FxcaF27do8ffpUWR4aGoqxsTGpUqXSqJVLly4dISEhX91XdHQ00dHR31kiIYQQQoifL1kkcrNnz+bGjRu4uroSGhr6nzs7wKckrlatWtStW5dHjx5pvHf16lWio6OpWLEiu3btAiBXrlxkzZpVOjoIIYQQQucki0QuS5YstGrVSpnR4b/y8vKiUaNGtG7dmoiICKUm7+3bt0RFRfHu3TvWrl3L2LFjefXqFe/evWPSpEmcP39eEjkhhBBC6JxkkcidOHECR0fHH07kOnbsCKDUtiVwc3Nj/fr1AAwbNoz4+HhWrlyJkZERR44cwcPD44c+VwghhBBCG5JFIrd//37GjRuHg4PDVzs7+P
j4/Kv9pE2b9h/X+fjxIwMHDmTgwIH/KVYhhBBCiOQiWSRy06ZNA/hqzdh/6ewghBBCCPE7SBaJnMx1KoQQQgjx/ZL
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def forward_backward(func, embeddings):\n",
" if embeddings.grad is not None:\n",
" embeddings.grad.zero_()\n",
"\n",
" output = func(embeddings)\n",
" loss = output.sum()\n",
" loss.backward()\n",
"\n",
"\n",
"def time_pytorch_function_forward_backward(func, *input, num_repeats = 1_000):\n",
" # CUDA is asynchronous, so we can't use Python's time module\n",
" start = torch.cuda.Event(enable_timing=True)\n",
" end = torch.cuda.Event(enable_timing=True)\n",
"\n",
" # Warmup\n",
" for _ in range(5):\n",
" forward_backward(func, *input)\n",
" torch.cuda.synchronize()\n",
"\n",
" start.record()\n",
" for _ in range(num_repeats):\n",
" forward_backward(func, *input)\n",
" torch.cuda.synchronize()\n",
" end.record()\n",
" torch.cuda.synchronize()\n",
" return start.elapsed_time(end) / num_repeats\n",
"\n",
"\n",
"execution_times = [time_pytorch_function_forward_backward(fn, embeddings_cuda) for name, fn in functions.items()]\n",
"\n",
"\n",
"# Plotting\n",
"\n",
"\n",
"fig, ax = plt.subplots()\n",
"bars = plt.bar(functions.keys(), execution_times)\n",
"\n",
"plt.ylabel(\"Execution time (ms)\")\n",
"plt.xticks(rotation=45, ha=\"right\")\n",
"\n",
"# Calculate new ylim with a margin\n",
"max_execution_time = max(execution_times)\n",
"upper_ylim = max_execution_time + 0.2 * max_execution_time # Adding a 20% margin\n",
"\n",
"plt.ylim(0, upper_ylim) # Setting new ylim\n",
"\n",
"# Annotate bars with execution times\n",
"for bar in bars:\n",
" yval = bar.get_height()\n",
" plt.text(bar.get_x() + bar.get_width()/2, yval + (0.05 * upper_ylim), round(yval, 2), ha=\"center\", va=\"bottom\")\n",
"\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig(\"2_forward-and-backward.pdf\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "1gWX-Ayqia1k",
"metadata": {
"id": "1gWX-Ayqia1k"
},
"source": [
"<br>\n",
" \n",
"\n",
"\n",
"## Speed comparison (Nvidia A100 GPU) with warmup and compilation (forward and backward pass)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "LQDiAPooiYAz",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "LQDiAPooiYAz",
"outputId": "09d66064-0986-480e-ee39-c9a5faf7dcf5"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] WON'T CONVERT forward <ipython-input-12-a390090d40bb> line 30 \n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] due to: \n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] Traceback (most recent call last):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py\", line 1438, in _call_user_compiler\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_fn = compiler_fn(gm, self.example_inputs())\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/repro/after_dynamo.py\", line 129, in __call__\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_gm = compiler_fn(gm, example_inputs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/__init__.py\", line 2236, in __call__\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return compile_fx(model_, inputs_, config_patches=self.config)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py\", line 1507, in compile_fx\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return aot_autograd(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/backends/common.py\", line 72, in __call__\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] cg = aot_module_simplified(gm, example_inputs, **self.kwargs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py\", line 1033, in aot_module_simplified\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_fn = dispatch_and_compile()\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py\", line 1022, in dispatch_and_compile\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_fn, _ = create_aot_dispatcher_function(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py\", line 435, in create_aot_dispatcher_function\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return _create_aot_dispatcher_function(flat_fn, flat_args, aot_config)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py\", line 736, in _create_aot_dispatcher_function\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_fn, fw_metadata = compiler_fn(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py\", line 564, in aot_dispatch_autograd\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py\", line 1336, in fw_compiler_base\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return _fw_compiler_base(model, example_inputs, is_inference)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py\", line 1407, in _fw_compiler_base\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return inner_compile(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py\", line 468, in compile_fx_inner\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return wrap_compiler_debug(_compile_fx_inner, compiler_name=\"inductor\")(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/repro/after_aot.py\", line 85, in debug_wrapper\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] inner_compiled_fn = compiler_fn(gm, example_inputs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py\", line 653, in _compile_fx_inner\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_graph = FxGraphCache.load(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/codecache.py\", line 1319, in load\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_graph = compile_fx_fn(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py\", line 563, in codegen_and_compile\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py\", line 845, in fx_codegen_and_compile\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] graph.run(*example_inputs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/graph.py\", line 772, in run\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return super().run(*args)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/fx/interpreter.py\", line 147, in run\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] self.env[node] = self.run_node(node)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/graph.py\", line 1280, in run_node\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] result = super().run_node(n)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/fx/interpreter.py\", line 204, in run_node\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return getattr(self, n.op)(n.target, args, kwargs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/graph.py\", line 1037, in call_function\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] raise LoweringException(e, target, args, kwargs).with_traceback(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/graph.py\", line 1034, in call_function\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] out = lowerings[target](*args, **kwargs) # type: ignore[index]\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/lowering.py\", line 323, in wrapped\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] out = decomp_fn(*args, **kwargs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/kernel/flex_attention.py\", line 627, in flex_attention\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] query.get_stride(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/ir.py\", line 6151, in __getattr__\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] fn = getattr(self.data, name)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch._inductor.exc.LoweringException: AttributeError: 'View' object has no attribute 'get_stride'\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] target: flex_attention\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[0]: TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] View(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ExternKernelOut(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name='extern_kernels.mm',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] name=buf0,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] layout=FixedLayout('cuda', torch.float32, size=[8192, 2304], stride=[2304, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] inputs=[ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[8, 1024, 768], stride=[786432, 768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[8192, 768], stride=[768, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[2304, 768], stride=[768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[768, 2304], stride=[1, 768]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] constant_args=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwargs={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] output_view=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name=extern_kernels.mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] cpp_kernel_name=at::mm_out,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ordered_kwargs_for_cpp_kernel=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] op_overload=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] arg_properties=[{}, {}],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwarg_properties=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] unbacked_bindings={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] mutation_outputs=[],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[1, 8, 12, 1024, 64], stride=[768, 2359296, 64, 2304, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] size=[8, 12, 1024, 64],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] reindex=lambda i0, i1, i2, i3: [0, i0, i1, i2, i3],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[1]: TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] View(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ExternKernelOut(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name='extern_kernels.mm',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] name=buf0,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] layout=FixedLayout('cuda', torch.float32, size=[8192, 2304], stride=[2304, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] inputs=[ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[8, 1024, 768], stride=[786432, 768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[8192, 768], stride=[768, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[2304, 768], stride=[768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[768, 2304], stride=[1, 768]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] constant_args=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwargs={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] output_view=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name=extern_kernels.mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] cpp_kernel_name=at::mm_out,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ordered_kwargs_for_cpp_kernel=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] op_overload=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] arg_properties=[{}, {}],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwarg_properties=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] unbacked_bindings={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] mutation_outputs=[],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[1, 8, 12, 1024, 64], stride=[768, 2359296, 64, 2304, 1], offset=768),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select_1])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] size=[8, 12, 1024, 64],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] reindex=lambda i0, i1, i2, i3: [0, i0, i1, i2, i3],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select_1])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[2]: TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] View(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ExternKernelOut(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name='extern_kernels.mm',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] name=buf0,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] layout=FixedLayout('cuda', torch.float32, size=[8192, 2304], stride=[2304, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] inputs=[ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[8, 1024, 768], stride=[786432, 768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[8192, 768], stride=[768, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[2304, 768], stride=[768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[768, 2304], stride=[1, 768]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] constant_args=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwargs={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] output_view=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name=extern_kernels.mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] cpp_kernel_name=at::mm_out,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ordered_kwargs_for_cpp_kernel=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] op_overload=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] arg_properties=[{}, {}],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwarg_properties=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] unbacked_bindings={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] mutation_outputs=[],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[1, 8, 12, 1024, 64], stride=[768, 2359296, 64, 2304, 1], offset=1536),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select_2])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] size=[8, 12, 1024, 64],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] reindex=lambda i0, i1, i2, i3: [0, i0, i1, i2, i3],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select_2])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[4]: (TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([slice_2])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_4', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([slice_4])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_5', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([slice_6])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([slice_8])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), TensorBox(StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ComputedBuffer(name='buf6', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]), data=Pointwise(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] 'cuda',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch.int32,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] def inner_fn(index):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] _, _, i2 = index\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp0 = ops.load(buf3, i2)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp1 = ops.to_dtype(tmp0, torch.int32, src_dtype=torch.int64)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return tmp1\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ranges=[1, 1, 8],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=convert_element_type,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([convert_element_type])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )), TensorBox(StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ComputedBuffer(name='buf7', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]), data=Pointwise(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] 'cuda',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch.int32,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] def inner_fn(index):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] _, _, i2, i3 = index\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp0 = ops.load(buf5, i3 + 8 * i2)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int16)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return tmp2\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ranges=[1, 1, 8, 8],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=clone_1,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([convert_element_type_1, clone_1, sort])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )), TensorBox(StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ComputedBuffer(name='buf13', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]), data=Pointwise(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] 'cuda',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch.int32,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] def inner_fn(index):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] _, _, i2 = index\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp0 = ops.load(buf10, i2)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp1 = ops.to_dtype(tmp0, torch.int32, src_dtype=torch.int64)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return tmp1\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ranges=[1, 1, 8],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=convert_element_type_2,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([convert_element_type_2])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )), TensorBox(StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ComputedBuffer(name='buf14', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]), data=Pointwise(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] 'cuda',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch.int32,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] def inner_fn(index):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] _, _, i2, i3 = index\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp0 = ops.load(buf12, i3 + 8 * i2)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int16)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return tmp2\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ranges=[1, 1, 8, 8],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=clone_3,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([clone_3, convert_element_type_3, sort_1])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )), 128, 128, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[5]: 0.125\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': True}\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[7]: ()\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[8]: ()\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] \n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] The above exception was the direct cause of the following exception:\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] \n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] Traceback (most recent call last):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py\", line 1039, in __call__\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] result = self._inner_convert(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py\", line 514, in __call__\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return _compile(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py\", line 902, in _compile\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] guarded_code = compile_inner(code, one_graph, hooks, transform)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py\", line 653, in compile_inner\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return _compile_inner(code, one_graph, hooks, transform)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py\", line 85, in wrapper_function\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return StrobelightCompileTimeProfiler.profile_compile_time(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_strobelight/compile_time_profiler.py\", line 129, in profile_compile_time\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return func(*args, **kwargs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py\", line 686, in _compile_inner\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] out_code = transform_code_object(code, transform)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py\", line 1322, in transform_code_object\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] transformations(instructions, code_options)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py\", line 208, in _fn\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return fn(*args, **kwargs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py\", line 622, in transform\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tracer.run()\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py\", line 2731, in run\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] super().run()\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py\", line 958, in run\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] while self.step():\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py\", line 870, in step\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] self.dispatch_table[inst.opcode](self, inst)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py\", line 2922, in RETURN_VALUE\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] self._return(inst)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py\", line 2907, in _return\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] self.output.compile_subgraph(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py\", line 1109, in compile_subgraph\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py\", line 1361, in compile_and_call_fx_graph\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_fn = self.call_user_compiler(gm)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py\", line 1408, in call_user_compiler\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return self._call_user_compiler(gm)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py\", line 1457, in _call_user_compiler\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] raise BackendCompilerFailed(self.compiler_fn, e) from e\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] LoweringException: AttributeError: 'View' object has no attribute 'get_stride'\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] target: flex_attention\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[0]: TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] View(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ExternKernelOut(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name='extern_kernels.mm',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] name=buf0,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] layout=FixedLayout('cuda', torch.float32, size=[8192, 2304], stride=[2304, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] inputs=[ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[8, 1024, 768], stride=[786432, 768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[8192, 768], stride=[768, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[2304, 768], stride=[768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[768, 2304], stride=[1, 768]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] constant_args=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwargs={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] output_view=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name=extern_kernels.mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] cpp_kernel_name=at::mm_out,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ordered_kwargs_for_cpp_kernel=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] op_overload=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] arg_properties=[{}, {}],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwarg_properties=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] unbacked_bindings={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] mutation_outputs=[],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[1, 8, 12, 1024, 64], stride=[768, 2359296, 64, 2304, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] size=[8, 12, 1024, 64],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] reindex=lambda i0, i1, i2, i3: [0, i0, i1, i2, i3],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[1]: TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] View(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ExternKernelOut(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name='extern_kernels.mm',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] name=buf0,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] layout=FixedLayout('cuda', torch.float32, size=[8192, 2304], stride=[2304, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] inputs=[ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[8, 1024, 768], stride=[786432, 768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[8192, 768], stride=[768, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[2304, 768], stride=[768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[768, 2304], stride=[1, 768]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] constant_args=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwargs={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] output_view=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name=extern_kernels.mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] cpp_kernel_name=at::mm_out,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ordered_kwargs_for_cpp_kernel=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] op_overload=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] arg_properties=[{}, {}],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwarg_properties=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] unbacked_bindings={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] mutation_outputs=[],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[1, 8, 12, 1024, 64], stride=[768, 2359296, 64, 2304, 1], offset=768),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select_1])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] size=[8, 12, 1024, 64],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] reindex=lambda i0, i1, i2, i3: [0, i0, i1, i2, i3],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select_1])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[2]: TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] View(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ExternKernelOut(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name='extern_kernels.mm',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] name=buf0,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] layout=FixedLayout('cuda', torch.float32, size=[8192, 2304], stride=[2304, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] inputs=[ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[8, 1024, 768], stride=[786432, 768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[8192, 768], stride=[768, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[2304, 768], stride=[768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[768, 2304], stride=[1, 768]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] constant_args=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwargs={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] output_view=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name=extern_kernels.mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] cpp_kernel_name=at::mm_out,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ordered_kwargs_for_cpp_kernel=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] op_overload=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] arg_properties=[{}, {}],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwarg_properties=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] unbacked_bindings={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] mutation_outputs=[],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[1, 8, 12, 1024, 64], stride=[768, 2359296, 64, 2304, 1], offset=1536),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select_2])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] size=[8, 12, 1024, 64],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] reindex=lambda i0, i1, i2, i3: [0, i0, i1, i2, i3],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select_2])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[4]: (TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([slice_2])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_4', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([slice_4])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_5', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([slice_6])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([slice_8])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), TensorBox(StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ComputedBuffer(name='buf6', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]), data=Pointwise(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] 'cuda',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch.int32,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] def inner_fn(index):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] _, _, i2 = index\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp0 = ops.load(buf3, i2)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp1 = ops.to_dtype(tmp0, torch.int32, src_dtype=torch.int64)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return tmp1\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ranges=[1, 1, 8],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=convert_element_type,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([convert_element_type])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )), TensorBox(StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ComputedBuffer(name='buf7', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]), data=Pointwise(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] 'cuda',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch.int32,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] def inner_fn(index):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] _, _, i2, i3 = index\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp0 = ops.load(buf5, i3 + 8 * i2)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int16)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return tmp2\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ranges=[1, 1, 8, 8],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=clone_1,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([convert_element_type_1, clone_1, sort])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )), TensorBox(StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ComputedBuffer(name='buf13', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]), data=Pointwise(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] 'cuda',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch.int32,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] def inner_fn(index):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] _, _, i2 = index\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp0 = ops.load(buf10, i2)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp1 = ops.to_dtype(tmp0, torch.int32, src_dtype=torch.int64)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return tmp1\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ranges=[1, 1, 8],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=convert_element_type_2,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([convert_element_type_2])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )), TensorBox(StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ComputedBuffer(name='buf14', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]), data=Pointwise(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] 'cuda',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch.int32,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] def inner_fn(index):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] _, _, i2, i3 = index\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp0 = ops.load(buf12, i3 + 8 * i2)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int16)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return tmp2\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ranges=[1, 1, 8, 8],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=clone_3,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([clone_3, convert_element_type_3, sort_1])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )), 128, 128, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[5]: 0.125\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': True}\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[7]: ()\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[8]: ()\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] \n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] Set TORCH_LOGS=\"+dynamo\" and TORCHDYNAMO_VERBOSE=1 for more information\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] \n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] Traceback (most recent call last):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py\", line 1438, in _call_user_compiler\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_fn = compiler_fn(gm, self.example_inputs())\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/repro/after_dynamo.py\", line 129, in __call__\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_gm = compiler_fn(gm, example_inputs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/__init__.py\", line 2236, in __call__\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return compile_fx(model_, inputs_, config_patches=self.config)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py\", line 1507, in compile_fx\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return aot_autograd(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/backends/common.py\", line 72, in __call__\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] cg = aot_module_simplified(gm, example_inputs, **self.kwargs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py\", line 1033, in aot_module_simplified\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_fn = dispatch_and_compile()\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py\", line 1022, in dispatch_and_compile\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_fn, _ = create_aot_dispatcher_function(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py\", line 435, in create_aot_dispatcher_function\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return _create_aot_dispatcher_function(flat_fn, flat_args, aot_config)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py\", line 736, in _create_aot_dispatcher_function\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_fn, fw_metadata = compiler_fn(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py\", line 564, in aot_dispatch_autograd\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py\", line 1336, in fw_compiler_base\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return _fw_compiler_base(model, example_inputs, is_inference)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py\", line 1407, in _fw_compiler_base\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return inner_compile(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py\", line 468, in compile_fx_inner\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return wrap_compiler_debug(_compile_fx_inner, compiler_name=\"inductor\")(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/repro/after_aot.py\", line 85, in debug_wrapper\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] inner_compiled_fn = compiler_fn(gm, example_inputs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py\", line 653, in _compile_fx_inner\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_graph = FxGraphCache.load(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/codecache.py\", line 1319, in load\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_graph = compile_fx_fn(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py\", line 563, in codegen_and_compile\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py\", line 845, in fx_codegen_and_compile\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] graph.run(*example_inputs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/graph.py\", line 772, in run\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return super().run(*args)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/fx/interpreter.py\", line 147, in run\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] self.env[node] = self.run_node(node)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/graph.py\", line 1280, in run_node\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] result = super().run_node(n)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/fx/interpreter.py\", line 204, in run_node\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return getattr(self, n.op)(n.target, args, kwargs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/graph.py\", line 1037, in call_function\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] raise LoweringException(e, target, args, kwargs).with_traceback(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/graph.py\", line 1034, in call_function\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] out = lowerings[target](*args, **kwargs) # type: ignore[index]\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/lowering.py\", line 323, in wrapped\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] out = decomp_fn(*args, **kwargs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/kernel/flex_attention.py\", line 627, in flex_attention\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] query.get_stride(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_inductor/ir.py\", line 6151, in __getattr__\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] fn = getattr(self.data, name)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch._inductor.exc.LoweringException: AttributeError: 'View' object has no attribute 'get_stride'\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] target: flex_attention\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[0]: TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] View(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ExternKernelOut(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name='extern_kernels.mm',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] name=buf0,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] layout=FixedLayout('cuda', torch.float32, size=[8192, 2304], stride=[2304, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] inputs=[ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[8, 1024, 768], stride=[786432, 768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[8192, 768], stride=[768, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[2304, 768], stride=[768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[768, 2304], stride=[1, 768]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] constant_args=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwargs={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] output_view=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name=extern_kernels.mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] cpp_kernel_name=at::mm_out,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ordered_kwargs_for_cpp_kernel=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] op_overload=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] arg_properties=[{}, {}],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwarg_properties=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] unbacked_bindings={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] mutation_outputs=[],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[1, 8, 12, 1024, 64], stride=[768, 2359296, 64, 2304, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] size=[8, 12, 1024, 64],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] reindex=lambda i0, i1, i2, i3: [0, i0, i1, i2, i3],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[1]: TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] View(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ExternKernelOut(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name='extern_kernels.mm',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] name=buf0,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] layout=FixedLayout('cuda', torch.float32, size=[8192, 2304], stride=[2304, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] inputs=[ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[8, 1024, 768], stride=[786432, 768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[8192, 768], stride=[768, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[2304, 768], stride=[768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[768, 2304], stride=[1, 768]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] constant_args=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwargs={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] output_view=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name=extern_kernels.mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] cpp_kernel_name=at::mm_out,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ordered_kwargs_for_cpp_kernel=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] op_overload=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] arg_properties=[{}, {}],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwarg_properties=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] unbacked_bindings={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] mutation_outputs=[],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[1, 8, 12, 1024, 64], stride=[768, 2359296, 64, 2304, 1], offset=768),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select_1])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] size=[8, 12, 1024, 64],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] reindex=lambda i0, i1, i2, i3: [0, i0, i1, i2, i3],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select_1])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[2]: TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] View(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ExternKernelOut(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name='extern_kernels.mm',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] name=buf0,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] layout=FixedLayout('cuda', torch.float32, size=[8192, 2304], stride=[2304, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] inputs=[ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[8, 1024, 768], stride=[786432, 768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[8192, 768], stride=[768, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[2304, 768], stride=[768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[768, 2304], stride=[1, 768]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] constant_args=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwargs={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] output_view=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name=extern_kernels.mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] cpp_kernel_name=at::mm_out,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ordered_kwargs_for_cpp_kernel=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] op_overload=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] arg_properties=[{}, {}],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwarg_properties=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] unbacked_bindings={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] mutation_outputs=[],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[1, 8, 12, 1024, 64], stride=[768, 2359296, 64, 2304, 1], offset=1536),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select_2])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] size=[8, 12, 1024, 64],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] reindex=lambda i0, i1, i2, i3: [0, i0, i1, i2, i3],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select_2])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[4]: (TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([slice_2])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_4', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([slice_4])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_5', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([slice_6])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([slice_8])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), TensorBox(StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ComputedBuffer(name='buf6', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]), data=Pointwise(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] 'cuda',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch.int32,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] def inner_fn(index):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] _, _, i2 = index\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp0 = ops.load(buf3, i2)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp1 = ops.to_dtype(tmp0, torch.int32, src_dtype=torch.int64)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return tmp1\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ranges=[1, 1, 8],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=convert_element_type,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([convert_element_type])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )), TensorBox(StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ComputedBuffer(name='buf7', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]), data=Pointwise(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] 'cuda',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch.int32,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] def inner_fn(index):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] _, _, i2, i3 = index\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp0 = ops.load(buf5, i3 + 8 * i2)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int16)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return tmp2\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ranges=[1, 1, 8, 8],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=clone_1,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([convert_element_type_1, clone_1, sort])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )), TensorBox(StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ComputedBuffer(name='buf13', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]), data=Pointwise(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] 'cuda',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch.int32,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] def inner_fn(index):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] _, _, i2 = index\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp0 = ops.load(buf10, i2)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp1 = ops.to_dtype(tmp0, torch.int32, src_dtype=torch.int64)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return tmp1\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ranges=[1, 1, 8],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=convert_element_type_2,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([convert_element_type_2])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )), TensorBox(StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ComputedBuffer(name='buf14', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]), data=Pointwise(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] 'cuda',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch.int32,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] def inner_fn(index):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] _, _, i2, i3 = index\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp0 = ops.load(buf12, i3 + 8 * i2)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int16)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return tmp2\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ranges=[1, 1, 8, 8],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=clone_3,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([clone_3, convert_element_type_3, sort_1])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )), 128, 128, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[5]: 0.125\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': True}\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[7]: ()\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[8]: ()\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] \n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] The above exception was the direct cause of the following exception:\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] \n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] Traceback (most recent call last):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py\", line 1039, in __call__\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] result = self._inner_convert(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py\", line 514, in __call__\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return _compile(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py\", line 902, in _compile\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] guarded_code = compile_inner(code, one_graph, hooks, transform)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py\", line 653, in compile_inner\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return _compile_inner(code, one_graph, hooks, transform)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py\", line 85, in wrapper_function\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return StrobelightCompileTimeProfiler.profile_compile_time(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_strobelight/compile_time_profiler.py\", line 129, in profile_compile_time\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return func(*args, **kwargs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py\", line 686, in _compile_inner\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] out_code = transform_code_object(code, transform)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py\", line 1322, in transform_code_object\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] transformations(instructions, code_options)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py\", line 208, in _fn\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return fn(*args, **kwargs)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py\", line 622, in transform\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tracer.run()\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py\", line 2731, in run\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] super().run()\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py\", line 958, in run\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] while self.step():\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py\", line 870, in step\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] self.dispatch_table[inst.opcode](self, inst)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py\", line 2922, in RETURN_VALUE\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] self._return(inst)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py\", line 2907, in _return\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] self.output.compile_subgraph(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py\", line 1109, in compile_subgraph\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py\", line 1361, in compile_and_call_fx_graph\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] compiled_fn = self.call_user_compiler(gm)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py\", line 1408, in call_user_compiler\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return self._call_user_compiler(gm)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] File \"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py\", line 1457, in _call_user_compiler\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] raise BackendCompilerFailed(self.compiler_fn, e) from e\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] LoweringException: AttributeError: 'View' object has no attribute 'get_stride'\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] target: flex_attention\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[0]: TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] View(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ExternKernelOut(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name='extern_kernels.mm',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] name=buf0,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] layout=FixedLayout('cuda', torch.float32, size=[8192, 2304], stride=[2304, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] inputs=[ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[8, 1024, 768], stride=[786432, 768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[8192, 768], stride=[768, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[2304, 768], stride=[768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[768, 2304], stride=[1, 768]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] constant_args=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwargs={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] output_view=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name=extern_kernels.mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] cpp_kernel_name=at::mm_out,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ordered_kwargs_for_cpp_kernel=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] op_overload=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] arg_properties=[{}, {}],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwarg_properties=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] unbacked_bindings={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] mutation_outputs=[],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[1, 8, 12, 1024, 64], stride=[768, 2359296, 64, 2304, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] size=[8, 12, 1024, 64],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] reindex=lambda i0, i1, i2, i3: [0, i0, i1, i2, i3],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[1]: TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] View(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ExternKernelOut(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name='extern_kernels.mm',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] name=buf0,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] layout=FixedLayout('cuda', torch.float32, size=[8192, 2304], stride=[2304, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] inputs=[ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[8, 1024, 768], stride=[786432, 768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[8192, 768], stride=[768, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[2304, 768], stride=[768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[768, 2304], stride=[1, 768]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] constant_args=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwargs={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] output_view=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name=extern_kernels.mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] cpp_kernel_name=at::mm_out,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ordered_kwargs_for_cpp_kernel=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] op_overload=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] arg_properties=[{}, {}],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwarg_properties=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] unbacked_bindings={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] mutation_outputs=[],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[1, 8, 12, 1024, 64], stride=[768, 2359296, 64, 2304, 1], offset=768),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select_1])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] size=[8, 12, 1024, 64],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] reindex=lambda i0, i1, i2, i3: [0, i0, i1, i2, i3],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select_1])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[2]: TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] View(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ExternKernelOut(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name='extern_kernels.mm',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] name=buf0,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] layout=FixedLayout('cuda', torch.float32, size=[8192, 2304], stride=[2304, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] inputs=[ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[8, 1024, 768], stride=[786432, 768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[8192, 768], stride=[768, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[2304, 768], stride=[768, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[768, 2304], stride=[1, 768]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] constant_args=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwargs={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] output_view=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] python_kernel_name=extern_kernels.mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] cpp_kernel_name=at::mm_out,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ordered_kwargs_for_cpp_kernel=(),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] op_overload=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] arg_properties=[{}, {}],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] kwarg_properties=None,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] unbacked_bindings={},\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] mutation_outputs=[],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=mm,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([mm])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.float32, size=[1, 8, 12, 1024, 64], stride=[768, 2359296, 64, 2304, 1], offset=1536),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select_2])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] size=[8, 12, 1024, 64],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] reindex=lambda i0, i1, i2, i3: [0, i0, i1, i2, i3],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([select_2])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[4]: (TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([slice_2])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_4', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([slice_4])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_5', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([slice_6])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), TensorBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ReinterpretView(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]),\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([slice_8])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ), TensorBox(StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ComputedBuffer(name='buf6', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]), data=Pointwise(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] 'cuda',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch.int32,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] def inner_fn(index):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] _, _, i2 = index\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp0 = ops.load(buf3, i2)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp1 = ops.to_dtype(tmp0, torch.int32, src_dtype=torch.int64)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return tmp1\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ranges=[1, 1, 8],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=convert_element_type,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([convert_element_type])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )), TensorBox(StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ComputedBuffer(name='buf7', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]), data=Pointwise(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] 'cuda',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch.int32,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] def inner_fn(index):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] _, _, i2, i3 = index\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp0 = ops.load(buf5, i3 + 8 * i2)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int16)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return tmp2\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ranges=[1, 1, 8, 8],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=clone_1,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([convert_element_type_1, clone_1, sort])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )), TensorBox(StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ComputedBuffer(name='buf13', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8], stride=[8, 8, 1]), data=Pointwise(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] 'cuda',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch.int32,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] def inner_fn(index):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] _, _, i2 = index\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp0 = ops.load(buf10, i2)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp1 = ops.to_dtype(tmp0, torch.int32, src_dtype=torch.int64)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return tmp1\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ranges=[1, 1, 8],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=convert_element_type_2,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([convert_element_type_2])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )), TensorBox(StorageBox(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ComputedBuffer(name='buf14', layout=FixedLayout('cuda', torch.int32, size=[1, 1, 8, 8], stride=[64, 64, 8, 1]), data=Pointwise(\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] 'cuda',\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] torch.int32,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] def inner_fn(index):\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] _, _, i2, i3 = index\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp0 = ops.load(buf12, i3 + 8 * i2)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int16)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] return tmp2\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ranges=[1, 1, 8, 8],\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origin_node=clone_3,\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] origins=OrderedSet([clone_3, convert_element_type_3, sort_1])\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] ))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] )), 128, 128, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[5]: 0.125\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': True}\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[7]: ()\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] args[8]: ()\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] \n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] Set TORCH_LOGS=\"+dynamo\" and TORCHDYNAMO_VERBOSE=1 for more information\n",
"W0810 13:55:08.918000 21035 torch/_dynamo/convert_frame.py:1100] \n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAnIAAAHWCAYAAADzS2TwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzddVhU6fv48ffQEiqoKKjYueraHWBgd2Bhi6IogoUitiJiJ7Yia6zd2N26YmNiYSCY2MD8/vDH+TLq7upn1fHA/bour3XOnDl7P54z59zzpMbGxkaLEEIIIYRQHQN9ByCEEEIIIf43ksgJIYQQQqiUJHJCCCGEEColiZwQQgghhEpJIieEEEIIoVKSyAkhhBBCqJQkckIIIYQQKiWJnBBCCCGEShnpOwA1srOzIzY2Vt9hCCGEECKZsrS05MGDB/+6nyRy38jOzo4LFy7oOwwhhBBCJHOFChX612ROErlvlFgTV6hQIamVE0IIIcR3Z2lpyYULF74qz5BE7n8UGxvLy5cv9R2GEEIIIVIwGewghBBCCKFSksgJIYQQQqiUJHJCCCGEEColiZwQQgghhEpJIieEEEIIoVKSyCUDZ86cISYm5rM/48eP/+L+rq6ubN68mRs3bnDjxg3Wrl1L8eLFdfaxsLAgICCA8+fPc+/ePY4cOUKHDh1+QmmEEEII8bVk+pFkoHr16hgaGiqvCxQowNq1a9mwYcMX969QoQJr167lxIkTvHv3jt69e7N69WoqVKigTDw4atQoKlWqRPfu3blz5w5OTk4EBgby8OFDQkNDf0q5hBBCCPHPpEYuGYiJiSEqKkr54+zszM2bNzl8+PAX9+/evTsLFy7kwoULXLt2DU9PTwwMDKhcubKyT+nSpVmxYgWHDx/m7t27BAcHc+HChc9q7oQQQgihPykukTMwMGDQoEH89ddf3Lt3j1OnTtG3b199h/XdGBsb07x5c5YtW/bVnzE3N8fIyIinT58q206cOEHt2rWxs7MDoGLFiuTOnZu9e/d+95iFEEII8b9JcU2rnp6edOzYkZ49exIeHk7RokWZMWMGL1++ZO7cufoO7z+rU6cOadKkYfny5V/9mWHDhvHw4UP279+vbPPx8WHy5MlcuHCBDx8+kJCQgJeXF0ePHv0RYQshhBDif5DiauRKlSrFtm3b2LlzJ3fv3mXTpk3s3bs32TQZtm3bll27dvHw4cOv2t/T05PGjRvTrl073r17p2zv2rUrJUuWpHXr1lStWpWhQ4cyfvx4qlSp8qNCF0II8YlvHcwG0KBBA44dO0ZkZCQHDx6kevXqOu8PGDCAY8eOcefOHWXAW4kSJX50UcQPkuISuZMnT1K5cmVy5coFwG+//UaZMmXYtWvXF/c3MTHByspK+WNpafkzw/0mWbJkoUqVKoSEhHzV/j179sTT05NmzZpx6dIlZbuZmRlDhgxhyJAhbN++nUuXLjF//nzWrVtHz549f1T4QgghPlG9enUKFCig/GnSpAnA3w5mK1WqFPPmzSMkJAQnJye2bt3K0qVLyZ8/v7LPjRs3GDhwIJUqVaJOnTrcuXOH1atXky5dup9SJvF9pbim1SlTpmBlZcWxY8eIj4/H0NCQMWPGsHr16i/u36dPHwYOHPiTo/zftG7dmsePH7Njx45/3bdXr154e3vTvHlzwsLCdN4zNjbGxMSEhIQEne3x8fEYGKS43F8IIfQmJiZG57Wnp+c/Dmbr1q0bu3fvZsaMGQD4+/vj6OhIly5d6NevHwBr1qzR+Yyfnx+urq789ttvHDhw4AeUQvxIKe6p3KhRI5o1a4abmxtOTk707NmTnj170rJlyy/uP2XKFLJnz678KVSo0E+O+OtoNBpat27NypUriY+P13lv1qxZ+Pn5Ka979+7NoEGD6N27N3fu3MHW1hZbW1ssLCwAePnyJYcOHWLEiBFUqFABBwcHWrVqhYuLC1u2bPmp5RJCCPHR1wxmK1WqlE
5/Z4A9e/ZQqlSpvz1mu3bteP78ORcuXPiu8YqfI8XVyI0YMYKpU6eybt06AC5fvkzWrFnp06cPK1as+Gz/9+/f8/79+58d5jerUqUKWbNm5Y8//vjsvcyZM+vUrnXs2BFTU1MWL16ss19AQIDS76Jr1674+fkxZ84c0qZNy7179xgzZgyLFi36oeUQQgjxZV8zmM3W1pbHjx/rbHv8+DG2trY625ydnZk3bx7m5uY8evSIpk2b8uTJkx8St/ixUlwilypVqi82GWo0Gj1F9H3s27fvb/s3NGzYUOd1sWLF/vV4UVFR9OrV67vEJoQQ4r/71sFs/+TQoUM4OjqSLl06XF1dWbBgAc7OzkRHR3+HSMXPlOKaVrdv3463tzc1atQga9as1K1bF3d3d7Zu3arv0IQQQogv+trBbFFRUWTIkEFnW4YMGYiKitLZ9vr1ayIiIjh16hSenp7ExcXRtm3b7x63+PFSXCLn4+PDxo0bCQwM5OjRo4wYMYIlS5YwduxYfYcmhBBCfNHXDmZLnJkhKUdHR06ePPmPnzMwMMDExOQ/xyl+vhTXtBobG4uvry++vr76DkUIIYT4V/82mO3BgweMGjUKgDlz5rBp0yZ69OjBzp07ady4MUWLFsXLywv4uJKPt7c3oaGhPHz4kHTp0tG5c2fs7Oz+dkoT8WtLcYmcEEIIoSbfMpjt5MmTuLm54evry5AhQ7h58yaurq6Eh4cDH/uE58mTh5YtW2JjY8PTp085c+YM9erV48qVKz+tTOL70djY2Gj1HcTXcHBwoFy5cmTJkgVzc3Oio6M5f/48J0+e1FmR4EezsrLi1q1bZM+enZcvX/60/68QQgghUoZvyTV++Rq5Zs2a0a1bN4oWLUpUVBQPHz7k7du3WFtbkz17dt69e8fq1auZOnUq9+7d03e4QgghhBA/zS+dyO3du5cPHz6wfPly2rdvz/3793XeNzExoVSpUjRu3Jjdu3fTv39/Nm7cqKdohRBCCCF+rl+6adXJyYm9e/d+1b7W1tY4ODhw9uzZHxqTNK0KIYQQ4kdKNk2rX5vEATx9+pSnT5/+wGiEEEIIIX4tqplHrkiRIhQoUEB5Xbt2bZYuXcqQIUMwNjbWY2RCCCGEEPqhmkRu0qRJ5M6dG4Bs2bIxb948Xr9+TYMGDRg+fLh+gxNCCCGE0INfumk1qVy5cnH+/Hng49qhR48epVu3bpQuXZr58+cnuwl+rboG6zuE7+LlvHb6DkGIZMXOzo5hw4ZRrVo1UqVKRUREBL169SIsLOyL+9erV4+OHTtSqFAhTE1NCQ8PJyAgQKfryoABAxg4cKDO565du0bZsmV/ZFGEEN+BahI5jUaDgcHHCsQqVaqwfft2ACIjI7GxsdFnaEII8VOkSZOGrVu3cujQIVxcXIiOjiZnzpw8e/bsbz9Trlw59u3bx+jRo3n+/DmtW7dm2bJlODs7Kz+OAS5fvkyTJk2U13FxcT+yKEKI70Q1iVxYWBh9+/Zl//79lC9fnn79+gEfm1kfP36s5+iEEOLH8/T0JDIykl69einb7ty584+f+bS1YvTo0dSuXZuaNWvqJHJxcXGfLawuhPj1qaaP3ODBgylSpAgBAQFMmjSJiIgIABo0aMCJEyf0HJ0QQvx4tWrVIiwsjIULFxIeHs7evXtxdXX9pmNoNBosLS0/q8XLmTMnFy9e5PTp0wQFBZE5c+bvGLkQ4kdRTY3cpUuXqFSp0mfbhw0b9tkiwkIIkRxly5aNjh07Mnv2bCZPnkyxYsXw9/fnw4cPrFix4quO4eHhgYWFBevXr1e2nT59Gg8PD65fv07GjBkZMGAAW7ZsoWLFisTGxv6g0gghvgfVJHJJWVhYKP3lEsnkvEKI5M7AwICwsDBGjx4NwPnz5ylQoAAdOnT4qkSuadOm9O/fH1dXV6Kjo5Xtu3fvVv5+6dIlTp8+zdmzZ2nYsOEXF2oX/50MaBPfi2oSOQcHBwICAqhQoQJmZmbKdo
1Gg1arxdbWVo/RCSHEj/fo0SOuXLmis+3q1avUr1//Xz/buHFjpkyZQqdOndi/f/8/7vvixQtu3LhBzpw5/1O8Qog
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
    "import torch._dynamo\n",
    "\n",
    "# Fall back to eager execution instead of raising an error if compilation fails\n",
    "torch._dynamo.config.suppress_errors = True\n",
"\n",
    "def prepare_function(fn):\n",
    "    # Wrap the function with torch.compile; kernel compilation is lazy and\n",
    "    # happens on the first call with real inputs\n",
    "    fn = torch.compile(fn)\n",
    "    return fn\n",
"\n",
"\n",
    "execution_times = [time_pytorch_function_forward_backward(prepare_function(fn), embeddings_cuda)\n",
    "                   for name, fn in functions.items()]\n",
"\n",
"\n",
"# Plotting\n",
"\n",
"fig, ax = plt.subplots()\n",
"bars = plt.bar(functions.keys(), execution_times)\n",
"\n",
"plt.ylabel(\"Execution time (ms)\")\n",
"plt.xticks(rotation=45, ha=\"right\")\n",
"\n",
"# Calculate new ylim with a margin\n",
"max_execution_time = max(execution_times)\n",
"upper_ylim = max_execution_time + 0.2 * max_execution_time # Adding a 20% margin\n",
"\n",
"plt.ylim(0, upper_ylim) # Setting new ylim\n",
"\n",
"# Annotate bars with execution times\n",
    "for bar in bars:\n",
    "    yval = bar.get_height()\n",
    "    plt.text(bar.get_x() + bar.get_width()/2, yval + (0.05 * upper_ylim), round(yval, 2), ha=\"center\", va=\"bottom\")\n",
"\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig(\"3_forward-and-backward-compiled.pdf\")\n",
"plt.show()"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "A100",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}