From c315213b8e801d7596d48731f428145b50c89b57 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Wed, 20 Mar 2024 08:18:25 -0700 Subject: [PATCH] Update zero to thunder to new extensibility example (PR2488) --- notebooks/zero_to_thunder.ipynb | 4258 +++++++++++++++++++++++-------- 1 file changed, 3264 insertions(+), 994 deletions(-) diff --git a/notebooks/zero_to_thunder.ipynb b/notebooks/zero_to_thunder.ipynb index 1c536f0cde..a1a888cc72 100644 --- a/notebooks/zero_to_thunder.ipynb +++ b/notebooks/zero_to_thunder.ipynb @@ -3,7 +3,11 @@ { "cell_type": "markdown", "id": "1638964c", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ "# Zero to Thunder\n", "\n", @@ -21,16 +25,18 @@ "source": [ "import sys\n", "sys.path.insert(0, '..')\n", - "import inspect\n", - "\n", "\n", - "import torch, thunder\n" + "import torch, thunder" ] }, { "cell_type": "markdown", "id": "54f87aba", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ "## Compiling a first module with Thunder\n", "\n", @@ -40,7 +46,7 @@ { "cell_type": "code", "execution_count": 2, - "id": "d6ca6328", + "id": "892be718", "metadata": {}, "outputs": [ { @@ -62,26 +68,26 @@ " self.fc_1 = torch.nn.Linear(n_embd, intermediate_size, bias=False)\n", " self.fc_2 = torch.nn.Linear(n_embd, intermediate_size, bias=False)\n", " self.proj = torch.nn.Linear(intermediate_size, n_embd, bias=False)\n", - "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " x_fc_1 = self.fc_1(x)\n", " x_fc_2 = self.fc_2(x)\n", " x = torch.nn.functional.silu(x_fc_1) * x_fc_2\n", " return self.proj(x)\n", - "\n", - "\n", "with torch.device(\"cuda\"):\n", " m = LLaMAMLP(4096, 11008)\n", "for p in m.parameters():\n", " p.requires_grad_(False)\n", - "\n", - "print(m)" + "print(m)\n" ] }, { "cell_type": "markdown", "id": "702ea054", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ "Now we can apply Thunder. This uses the most important function of Thunder, `thunder.jit`, which can be used to compile a `torch.nn.Module` or a function. It will wrap our MLP in a `ThunderModule`" ] @@ -125,8 +131,12 @@ }, { "cell_type": "markdown", - "id": "59db20f6", - "metadata": {}, + "id": "47d24f2d-0e89-4fe8-8154-9b50f2633e1b", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ "Our Thunder module computes (up to numerical accuracy) the same thing as our original model and for a small model like this, it also has approximately the same performance." ] @@ -135,15 +145,19 @@ "cell_type": "code", "execution_count": 5, "id": "7f4de1b3", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "deviation: 1.4901161193847656e-07\n", - "58.2 ms ± 306 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n", - "58.7 ms ± 50.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "61.3 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n", + "62.1 ms ± 89.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], @@ -157,20 +171,25 @@ }, { "cell_type": "markdown", - "id": "8835543e", - "metadata": {}, + "id": "7996acc7-de20-4aa5-80f0-1ab6042e2650", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ - "So what has changed?\n", - "Quite a bit!\n", + "So what has changed? 
Quite a bit!\n", "\n", - "When we call the Thunder module, it does the computation in a single function without control flow. And what's more, it applies optimizations, such as creating fusions for NVFuser to execute. We can see all this by showing the last computation trace:" + "When we call the Thunder module, it do the computation in a single function without control flow. And what's more, it applies optimizations, such as creating fusions for NVFuser to execute. We can see all this by showing the last computation trace:" ] }, { "cell_type": "code", "execution_count": 6, "id": "a6f4b77c", - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [ { "data": { @@ -221,8 +240,12 @@ }, { "cell_type": "markdown", - "id": "a0071924", - "metadata": {}, + "id": "2ef89186-70cd-4737-9695-ed282da2a56c", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, "source": [ "For more detail of what is going on in this trace:\n", "- Thunder has transformed the computation (more precisely, `m.__call__`) into a single function which has all the MLP parameters as arguments.\n", @@ -237,13 +260,17 @@ { "cell_type": "markdown", "id": "7749aed1", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ "## Compiling a more complex model\n", "\n", - "Obviously, we aim for larger models, so we can do the same with the entire LLama 2 (well, we have a smaller model here to be mild to our CI, but if you have a large GPU, just drop reducing the number of layers):\n", + "Obviously, we aim for larger models, so we can do the same with the entire LLama 2 (well, we have a smaller momdel here to be mild to our CI, but if you have a large GPU, just drop reducing the number of layers):\n", "\n", - "**NOTE**: For running the cells below, we require `litgpt` which can be installed with `pip install 'litgpt[all] @ git+https://github.com/Lightning-AI/litgpt'`. See [here](https://github.com/Lightning-AI/litgpt) to learn more about litgpt" + "**NOTE**: For running the cells below, we require `litgpt` which can be installed with `pip install 'litgpt[all] @ git+https://github.com/Lightning-AI/litgpt'`. See [here](https://github.com/Lightning-AI/litgpt) to learn more about litgpt." ] }, { @@ -260,7 +287,7 @@ " (transformer): ModuleDict(\n", " (wte): Embedding(32000, 4096)\n", " (h): ModuleList(\n", - " (0-3): 4 x Block(\n", + " (0-15): 16 x Block(\n", " (norm_1): RMSNorm()\n", " (attn): CausalSelfAttention(\n", " (attn): Linear(in_features=4096, out_features=12288, bias=False)\n", @@ -288,7 +315,8 @@ "from lit_gpt import GPT\n", "from thunder.tests.lit_gpt_model import Config\n", "cfg = Config.from_name('Llama-2-7b-hf')\n", - "cfg.n_layer = 4 # fewer layers\n", + "cfg.n_layer = 16 # fewer layers\n", + "torch.set_default_dtype(torch.bfloat16)\n", "with torch.device('cuda'):\n", " m = GPT(cfg)\n", "m\n" @@ -312,7 +340,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "deviation: 1.8477439880371094e-06\n" + "deviation: 0.03125\n" ] } ], @@ -329,22 +357,37 @@ }, { "cell_type": "markdown", - "id": "2f681093", - "metadata": {}, + "id": "9947e8df-cd2d-447d-90b9-ee08bb5a9fb2", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ - "Just like before, we can see the program it ran:" + "One thing to keep in mind here is that for bf16, the numerical accuracy impact of rearranging operations can be quite pronounced.\n", + "\n", + "Just like before, we can see the program it ran, it is a lot longer, though." 
] }, { "cell_type": "code", "execution_count": 9, "id": "ac7e8bc9", - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, { "data": { "text/plain": [ - "# Constructed by Delete Last Used (took 1 milliseconds)\n", + "# Constructed by Delete Last Used (took 10 milliseconds)\n", "import torch\n", "from torch import Tensor\n", "import torch.nn.functional\n", @@ -388,626 +431,2728 @@ " t31, \\\n", " t32, \\\n", " t33, \\\n", + " t34, \\\n", + " t35, \\\n", + " t36, \\\n", + " t37, \\\n", + " t38, \\\n", + " t39, \\\n", + " t40, \\\n", + " t41, \\\n", + " t42, \\\n", + " t43, \\\n", + " t44, \\\n", + " t45, \\\n", + " t46, \\\n", + " t47, \\\n", + " t48, \\\n", + " t49, \\\n", + " t50, \\\n", + " t51, \\\n", + " t52, \\\n", + " t53, \\\n", + " t54, \\\n", + " t55, \\\n", + " t56, \\\n", + " t57, \\\n", + " t58, \\\n", + " t59, \\\n", + " t60, \\\n", + " t61, \\\n", + " t62, \\\n", + " t63, \\\n", + " t64, \\\n", + " t65, \\\n", + " t66, \\\n", + " t67, \\\n", + " t68, \\\n", + " t69, \\\n", + " t70, \\\n", + " t71, \\\n", + " t72, \\\n", + " t73, \\\n", + " t74, \\\n", + " t75, \\\n", + " t76, \\\n", + " t77, \\\n", + " t78, \\\n", + " t79, \\\n", + " t80, \\\n", + " t81, \\\n", + " t82, \\\n", + " t83, \\\n", + " t84, \\\n", + " t85, \\\n", + " t86, \\\n", + " t87, \\\n", + " t88, \\\n", + " t89, \\\n", + " t90, \\\n", + " t91, \\\n", + " t92, \\\n", + " t93, \\\n", + " t94, \\\n", + " t95, \\\n", + " t96, \\\n", + " t97, \\\n", + " t98, \\\n", + " t99, \\\n", + " t100, \\\n", + " t101, \\\n", + " t102, \\\n", + " t103, \\\n", + " t104, \\\n", + " t105, \\\n", + " t106, \\\n", + " t107, \\\n", + " t108, \\\n", + " t109, \\\n", + " t110, \\\n", + " t111, \\\n", + " t112, \\\n", + " t113, \\\n", + " t114, \\\n", + " t115, \\\n", + " t116, \\\n", + " t117, \\\n", " = args\n", " del args\n", - " t38 = torch.nn.functional.embedding(t0, t33, None, None, 2.0, False, False) # t38: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t38 = ltorch.embedding(t0, t33, None, None, 2.0, False, False) # t38: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t334 = ltorch.reshape(t0, [512]) # t334: \"cuda:0 i64[512]\"\n", - " # t334 = prims.reshape(t0, (512,)) # t334: \"cuda:0 i64[512]\"\n", - " # t335 = prims.take(t33, t334, 0) # t335: \"cuda:0 f32[512, 4096]\"\n", - " # t38 = ltorch.reshape(t335, [1, 512, 4096]) # t38: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t38 = prims.reshape(t335, (1, 512, 4096)) # t38: \"cuda:0 f32[1, 512, 4096]\"\n", - " t34 = torch_slice_prim_impl(t1, [0, 0], [512, 128], [1, 1]) # t34: \"cuda:0 f32[512, 128]\"\n", - " t35 = torch_slice_prim_impl(t2, [0, 0], [512, 128], [1, 1]) # t35: \"cuda:0 f32[512, 128]\"\n", - " t374 = torch.unsqueeze(t17, 0) # t374: \"cuda:0 f32[1, 4096]\"\n", - " # t374 = ltorch.unsqueeze(t17, 0) # t374: \"cuda:0 f32[1, 4096]\"\n", - " # t374 = prims.broadcast_in_dim(t17, [1, 4096], [1]) # t374: \"cuda:0 f32[1, 4096]\"\n", - " t375 = torch.unsqueeze(t374, 1) # t375: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t375 = ltorch.unsqueeze(t374, 1) # t375: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t375 = prims.broadcast_in_dim(t374, [1, 1, 4096], [0, 2]) # t375: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t374\n", - " t47 = Tensor.expand(t375, (1, 512, 4096)) # t47: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t47 = ltorch.expand(t375, (1, 512, 4096)) # t47: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t47 = prims.broadcast_in_dim(t375, (1, 512, 4096), (0, 1, 2)) # t47: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t375\n", - " t475 = 
torch.unsqueeze(t24, 0) # t475: \"cuda:0 f32[1, 4096]\"\n", - " # t475 = ltorch.unsqueeze(t24, 0) # t475: \"cuda:0 f32[1, 4096]\"\n", - " # t475 = prims.broadcast_in_dim(t24, [1, 4096], [1]) # t475: \"cuda:0 f32[1, 4096]\"\n", - " t476 = torch.unsqueeze(t475, 1) # t476: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t476 = ltorch.unsqueeze(t475, 1) # t476: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t476 = prims.broadcast_in_dim(t475, [1, 1, 4096], [0, 2]) # t476: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t475\n", - " t311 = Tensor.expand(t476, (1, 512, 4096)) # t311: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t311 = ltorch.expand(t476, (1, 512, 4096)) # t311: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t311 = prims.broadcast_in_dim(t476, (1, 512, 4096), (0, 1, 2)) # t311: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t476\n", - " t478 = torch.unsqueeze(t16, 0) # t478: \"cuda:0 f32[1, 4096]\"\n", - " # t478 = ltorch.unsqueeze(t16, 0) # t478: \"cuda:0 f32[1, 4096]\"\n", - " # t478 = prims.broadcast_in_dim(t16, [1, 4096], [1]) # t478: \"cuda:0 f32[1, 4096]\"\n", - " t479 = torch.unsqueeze(t478, 1) # t479: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t479 = ltorch.unsqueeze(t478, 1) # t479: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t479 = prims.broadcast_in_dim(t478, [1, 1, 4096], [0, 2]) # t479: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t478\n", - " t331 = Tensor.expand(t479, (1, 512, 4096)) # t331: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t331 = ltorch.expand(t479, (1, 512, 4096)) # t331: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t331 = prims.broadcast_in_dim(t479, (1, 512, 4096), (0, 1, 2)) # t331: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t479\n", - " t403 = torch.unsqueeze(t21, 0) # t403: \"cuda:0 f32[1, 4096]\"\n", - " # t403 = ltorch.unsqueeze(t21, 0) # t403: \"cuda:0 f32[1, 4096]\"\n", - " # t403 = prims.broadcast_in_dim(t21, [1, 4096], [1]) # t403: \"cuda:0 f32[1, 4096]\"\n", - " t404 = torch.unsqueeze(t403, 1) # t404: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t404 = ltorch.unsqueeze(t403, 1) # t404: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t404 = prims.broadcast_in_dim(t403, [1, 1, 4096], [0, 2]) # t404: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t403\n", - " t98 = Tensor.expand(t404, (1, 512, 4096)) # t98: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t98 = ltorch.expand(t404, (1, 512, 4096)) # t98: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t98 = prims.broadcast_in_dim(t404, (1, 512, 4096), (0, 1, 2)) # t98: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t404\n", - " t406 = torch.unsqueeze(t18, 0) # t406: \"cuda:0 f32[1, 4096]\"\n", - " # t406 = ltorch.unsqueeze(t18, 0) # t406: \"cuda:0 f32[1, 4096]\"\n", - " # t406 = prims.broadcast_in_dim(t18, [1, 4096], [1]) # t406: \"cuda:0 f32[1, 4096]\"\n", - " t407 = torch.unsqueeze(t406, 1) # t407: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t407 = ltorch.unsqueeze(t406, 1) # t407: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t407 = prims.broadcast_in_dim(t406, [1, 1, 4096], [0, 2]) # t407: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t406\n", - " t118 = Tensor.expand(t407, (1, 512, 4096)) # t118: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t118 = ltorch.expand(t407, (1, 512, 4096)) # t118: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t118 = prims.broadcast_in_dim(t407, (1, 512, 4096), (0, 1, 2)) # t118: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t407\n", - " t427 = torch.unsqueeze(t22, 0) # t427: \"cuda:0 f32[1, 4096]\"\n", - " # t427 = ltorch.unsqueeze(t22, 0) # t427: \"cuda:0 f32[1, 4096]\"\n", - " # t427 = prims.broadcast_in_dim(t22, [1, 4096], [1]) # t427: \"cuda:0 f32[1, 4096]\"\n", - " t428 = torch.unsqueeze(t427, 1) # t428: \"cuda:0 f32[1, 1, 
4096]\"\n", - " # t428 = ltorch.unsqueeze(t427, 1) # t428: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t428 = prims.broadcast_in_dim(t427, [1, 1, 4096], [0, 2]) # t428: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t427\n", - " t169 = Tensor.expand(t428, (1, 512, 4096)) # t169: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t169 = ltorch.expand(t428, (1, 512, 4096)) # t169: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t169 = prims.broadcast_in_dim(t428, (1, 512, 4096), (0, 1, 2)) # t169: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t428\n", - " t430 = torch.unsqueeze(t19, 0) # t430: \"cuda:0 f32[1, 4096]\"\n", - " # t430 = ltorch.unsqueeze(t19, 0) # t430: \"cuda:0 f32[1, 4096]\"\n", - " # t430 = prims.broadcast_in_dim(t19, [1, 4096], [1]) # t430: \"cuda:0 f32[1, 4096]\"\n", - " t431 = torch.unsqueeze(t430, 1) # t431: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t431 = ltorch.unsqueeze(t430, 1) # t431: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t431 = prims.broadcast_in_dim(t430, [1, 1, 4096], [0, 2]) # t431: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t430\n", - " t189 = Tensor.expand(t431, (1, 512, 4096)) # t189: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t189 = ltorch.expand(t431, (1, 512, 4096)) # t189: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t189 = prims.broadcast_in_dim(t431, (1, 512, 4096), (0, 1, 2)) # t189: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t431\n", - " t451 = torch.unsqueeze(t23, 0) # t451: \"cuda:0 f32[1, 4096]\"\n", - " # t451 = ltorch.unsqueeze(t23, 0) # t451: \"cuda:0 f32[1, 4096]\"\n", - " # t451 = prims.broadcast_in_dim(t23, [1, 4096], [1]) # t451: \"cuda:0 f32[1, 4096]\"\n", - " t452 = torch.unsqueeze(t451, 1) # t452: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t452 = ltorch.unsqueeze(t451, 1) # t452: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t452 = prims.broadcast_in_dim(t451, [1, 1, 4096], [0, 2]) # t452: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t451\n", - " t240 = Tensor.expand(t452, (1, 512, 4096)) # t240: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t240 = ltorch.expand(t452, (1, 512, 4096)) # t240: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t240 = prims.broadcast_in_dim(t452, (1, 512, 4096), (0, 1, 2)) # t240: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t452\n", - " t454 = torch.unsqueeze(t20, 0) # t454: \"cuda:0 f32[1, 4096]\"\n", - " # t454 = ltorch.unsqueeze(t20, 0) # t454: \"cuda:0 f32[1, 4096]\"\n", - " # t454 = prims.broadcast_in_dim(t20, [1, 4096], [1]) # t454: \"cuda:0 f32[1, 4096]\"\n", - " t455 = torch.unsqueeze(t454, 1) # t455: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t455 = ltorch.unsqueeze(t454, 1) # t455: \"cuda:0 f32[1, 1, 4096]\"\n", - " # t455 = prims.broadcast_in_dim(t454, [1, 1, 4096], [0, 2]) # t455: \"cuda:0 f32[1, 1, 4096]\"\n", - " del t454\n", - " t260 = Tensor.expand(t455, (1, 512, 4096)) # t260: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t260 = ltorch.expand(t455, (1, 512, 4096)) # t260: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t260 = prims.broadcast_in_dim(t455, (1, 512, 4096), (0, 1, 2)) # t260: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t455\n", - " t395 = torch.unsqueeze(t34, 0) # t395: \"cuda:0 f32[1, 512, 128]\"\n", - " # t395 = ltorch.unsqueeze(t34, 0) # t395: \"cuda:0 f32[1, 512, 128]\"\n", - " # t395 = prims.broadcast_in_dim(t34, [1, 512, 128], [1, 2]) # t395: \"cuda:0 f32[1, 512, 128]\"\n", - " del t34\n", - " t396 = torch.unsqueeze(t395, 1) # t396: \"cuda:0 f32[1, 1, 512, 128]\"\n", - " # t396 = ltorch.unsqueeze(t395, 1) # t396: \"cuda:0 f32[1, 1, 512, 128]\"\n", - " # t396 = prims.broadcast_in_dim(t395, [1, 1, 512, 128], [0, 2, 3]) # t396: \"cuda:0 f32[1, 1, 512, 128]\"\n", - " del t395\n", - " t63 = 
Tensor.expand(t396, (1, 32, 512, 128)) # t63: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t63 = ltorch.expand(t396, (1, 32, 512, 128)) # t63: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t63 = prims.broadcast_in_dim(t396, (1, 32, 512, 128), (0, 1, 2, 3)) # t63: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t396\n", - " t398 = torch.unsqueeze(t35, 0) # t398: \"cuda:0 f32[1, 512, 128]\"\n", - " # t398 = ltorch.unsqueeze(t35, 0) # t398: \"cuda:0 f32[1, 512, 128]\"\n", - " # t398 = prims.broadcast_in_dim(t35, [1, 512, 128], [1, 2]) # t398: \"cuda:0 f32[1, 512, 128]\"\n", - " del t35\n", - " t399 = torch.unsqueeze(t398, 1) # t399: \"cuda:0 f32[1, 1, 512, 128]\"\n", - " # t399 = ltorch.unsqueeze(t398, 1) # t399: \"cuda:0 f32[1, 1, 512, 128]\"\n", - " # t399 = prims.broadcast_in_dim(t398, [1, 1, 512, 128], [0, 2, 3]) # t399: \"cuda:0 f32[1, 1, 512, 128]\"\n", - " del t398\n", - " t65 = Tensor.expand(t399, (1, 32, 512, 128)) # t65: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t65 = ltorch.expand(t399, (1, 32, 512, 128)) # t65: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t65 = prims.broadcast_in_dim(t399, (1, 32, 512, 128), (0, 1, 2, 3)) # t65: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t399\n", - " [t44, t48] = nvFusion0(t38, t47)\n", - " # t39 = prims.mul(t38, t38) # t39: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t40 = prims.sum(t39, (2,)) # t40: \"cuda:0 f32[1, 512]\"\n", - " # t41 = prims.broadcast_in_dim(t40, [1, 512, 1], [0, 1]) # t41: \"cuda:0 f32[1, 512, 1]\"\n", - " # t42 = prims.div(t41, 4096.0) # t42: \"cuda:0 f32[1, 512, 1]\"\n", - " # t43 = prims.add(t42, 1e-05) # t43: \"cuda:0 f32[1, 512, 1]\"\n", - " # t44 = prims.rsqrt(t43) # t44: \"cuda:0 f32[1, 512, 1]\"\n", - " # t45 = prims.broadcast_in_dim(t44, (1, 512, 4096), (0, 1, 2)) # t45: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t46 = prims.mul(t38, t45) # t46: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t48 = prims.mul(t46, t47) # t48: \"cuda:0 f32[1, 512, 4096]\"\n", - " t49 = torch.nn.functional.linear(t48, t3, None) # t49: \"cuda:0 f32[1, 512, 12288]\"\n", - " # t49 = ltorch.linear(t48, t3, None) # t49: \"cuda:0 f32[1, 512, 12288]\"\n", - " # t49 = prims.linear(t48, t3, None) # t49: \"cuda:0 f32[1, 512, 12288]\"\n", - " t50 = torch.reshape(t49, (1, 512, 32, 3, 128)) # t50: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " # t50 = ltorch.reshape(t49, (1, 512, 32, 3, 128)) # t50: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " # t50 = prims.reshape(t49, (1, 512, 32, 3, 128)) # t50: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " del t49\n", - " t51 = torch.permute(t50, (0, 2, 3, 1, 4)) # t51: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " # t51 = ltorch.permute(t50, (0, 2, 3, 1, 4)) # t51: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " # t51 = prims.transpose(t50, (0, 2, 3, 1, 4)) # t51: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " del t50\n", - " (t52, t53, t54) = torch.split(t51, (1, 1, 1), 2)\n", - " # (t52, t53, t54) = ltorch.split(t51, (1, 1, 1), 2)\n", - " # t52 = prims.slice_prim(t51, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t52: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " # t53 = prims.slice_prim(t51, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t53: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " # t54 = prims.slice_prim(t51, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t54: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " del t51\n", - " t55 = torch.reshape(t52, (1, 32, 512, 128)) # t55: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t55 = ltorch.reshape(t52, (1, 32, 512, 128)) # t55: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t55 = 
prims.reshape(t52, (1, 32, 512, 128)) # t55: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t52\n", - " t56 = torch.reshape(t53, (1, 32, 512, 128)) # t56: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t56 = ltorch.reshape(t53, (1, 32, 512, 128)) # t56: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t56 = prims.reshape(t53, (1, 32, 512, 128)) # t56: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t53\n", - " t57 = torch.reshape(t54, (1, 32, 512, 128)) # t57: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t57 = ltorch.reshape(t54, (1, 32, 512, 128)) # t57: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t57 = prims.reshape(t54, (1, 32, 512, 128)) # t57: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t54\n", - " t58 = torch_slice_prim_impl(t55, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t58: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " t68 = torch_slice_prim_impl(t56, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t68: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " t78 = torch_slice_prim_impl(t55, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t78: \"cuda:0 f32[1, 32, 512, 0]\"\n", - " del t55\n", - " t80 = torch_slice_prim_impl(t56, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t80: \"cuda:0 f32[1, 32, 512, 0]\"\n", - " del t56\n", - " t60 = torch_slice_prim_impl(t58, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t60: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t59 = torch_slice_prim_impl(t58, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t59: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t69 = torch_slice_prim_impl(t68, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t69: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t70 = torch_slice_prim_impl(t68, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t70: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " [t61, t71] = nvFusion1(t60, t70)\n", - " # t61 = prims.neg(t60) # t61: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " # t71 = prims.neg(t70) # t71: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " del t60, t70\n", - " t62 = torch.cat((t61, t59), -1) # t62: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t62 = ltorch.cat((t61, t59), -1) # t62: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t62 = prims.cat((t61, t59), -1) # t62: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t61, t59\n", - " t72 = torch.cat((t71, t69), -1) # t72: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t72 = ltorch.cat((t71, t69), -1) # t72: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t72 = prims.cat((t71, t69), -1) # t72: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t71, t69\n", - " [t67, t77] = nvFusion2(t58, t62, t63, t65, t68, t72)\n", - " # t64 = prims.mul(t58, t63) # t64: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t66 = prims.mul(t62, t65) # t66: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t67 = prims.add(t64, t66) # t67: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t74 = prims.mul(t68, t63) # t74: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t76 = prims.mul(t72, t65) # t76: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t77 = prims.add(t74, t76) # t77: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t58, t62, t68, t72\n", - " t79 = torch.cat((t67, t78), -1) # t79: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t79 = ltorch.cat((t67, t78), -1) # t79: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t79 = prims.cat((t67, t78), -1) # t79: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t67, t78\n", - " t81 = torch.cat((t77, t80), -1) # t81: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t81 = ltorch.cat((t77, t80), -1) # t81: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t81 = prims.cat((t77, t80), -1) # t81: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t77, t80\n", - " (t82, t83, t84, t85) = 
sdpaex_grad_forward_scaled_dot_product_efficient_attention(t79, t81, t57, None, 0.0, True, 0.08838834764831843)\n", - " t86 = torch.permute(t82, (0, 2, 1, 3)) # t86: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " # t86 = ltorch.permute(t82, (0, 2, 1, 3)) # t86: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " # t86 = prims.transpose(t82, (0, 2, 1, 3)) # t86: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " t87 = torch.reshape(t86, (1, 512, 4096)) # t87: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t87 = ltorch.reshape(t86, (1, 512, 4096)) # t87: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t87 = prims.reshape(t86, (1, 512, 4096)) # t87: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t86\n", - " t88 = torch.nn.functional.linear(t87, t25, None) # t88: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t88 = ltorch.linear(t87, t25, None) # t88: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t88 = prims.linear(t87, t25, None) # t88: \"cuda:0 f32[1, 512, 4096]\"\n", - " [t89, t95, t99] = nvFusion3(t38, t88, t98)\n", - " # t89 = prims.add(t88, t38) # t89: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t90 = prims.mul(t89, t89) # t90: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t91 = prims.sum(t90, (2,)) # t91: \"cuda:0 f32[1, 512]\"\n", - " # t92 = prims.broadcast_in_dim(t91, [1, 512, 1], [0, 1]) # t92: \"cuda:0 f32[1, 512, 1]\"\n", - " # t93 = prims.div(t92, 4096.0) # t93: \"cuda:0 f32[1, 512, 1]\"\n", - " # t94 = prims.add(t93, 1e-05) # t94: \"cuda:0 f32[1, 512, 1]\"\n", - " # t95 = prims.rsqrt(t94) # t95: \"cuda:0 f32[1, 512, 1]\"\n", - " # t96 = prims.broadcast_in_dim(t95, (1, 512, 4096), (0, 1, 2)) # t96: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t97 = prims.mul(t89, t96) # t97: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t99 = prims.mul(t97, t98) # t99: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t88\n", - " t101 = torch.nn.functional.linear(t99, t11, None) # t101: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t101 = ltorch.linear(t99, t11, None) # t101: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t101 = prims.linear(t99, t11, None) # t101: \"cuda:0 f32[1, 512, 11008]\"\n", - " t100 = torch.nn.functional.linear(t99, t7, None) # t100: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t100 = ltorch.linear(t99, t7, None) # t100: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t100 = prims.linear(t99, t7, None) # t100: \"cuda:0 f32[1, 512, 11008]\"\n", - " [t107] = nvFusion4(t100, t101)\n", - " # t102 = prims.neg(t100) # t102: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t103 = prims.exp(t102) # t103: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t104 = prims.add(1.0, t103) # t104: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t105 = prims.reciprocal(t104) # t105: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t106 = prims.mul(t100, t105) # t106: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t107 = prims.mul(t106, t101) # t107: \"cuda:0 f32[1, 512, 11008]\"\n", - " t108 = torch.nn.functional.linear(t107, t26, None) # t108: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t108 = ltorch.linear(t107, t26, None) # t108: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t108 = prims.linear(t107, t26, None) # t108: \"cuda:0 f32[1, 512, 4096]\"\n", - " [t109, t115, t119] = nvFusion5(t108, t118, t89)\n", - " # t109 = prims.add(t108, t89) # t109: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t110 = prims.mul(t109, t109) # t110: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t111 = prims.sum(t110, (2,)) # t111: \"cuda:0 f32[1, 512]\"\n", - " # t112 = prims.broadcast_in_dim(t111, [1, 512, 1], [0, 1]) # t112: \"cuda:0 f32[1, 512, 1]\"\n", - " # t113 = prims.div(t112, 4096.0) # t113: \"cuda:0 f32[1, 512, 1]\"\n", - " # t114 = prims.add(t113, 1e-05) # t114: \"cuda:0 f32[1, 512, 
1]\"\n", - " # t115 = prims.rsqrt(t114) # t115: \"cuda:0 f32[1, 512, 1]\"\n", - " # t116 = prims.broadcast_in_dim(t115, (1, 512, 4096), (0, 1, 2)) # t116: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t117 = prims.mul(t109, t116) # t117: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t119 = prims.mul(t117, t118) # t119: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t108\n", - " t120 = torch.nn.functional.linear(t119, t4, None) # t120: \"cuda:0 f32[1, 512, 12288]\"\n", - " # t120 = ltorch.linear(t119, t4, None) # t120: \"cuda:0 f32[1, 512, 12288]\"\n", - " # t120 = prims.linear(t119, t4, None) # t120: \"cuda:0 f32[1, 512, 12288]\"\n", - " t121 = torch.reshape(t120, (1, 512, 32, 3, 128)) # t121: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " # t121 = ltorch.reshape(t120, (1, 512, 32, 3, 128)) # t121: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " # t121 = prims.reshape(t120, (1, 512, 32, 3, 128)) # t121: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " del t120\n", - " t122 = torch.permute(t121, (0, 2, 3, 1, 4)) # t122: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " # t122 = ltorch.permute(t121, (0, 2, 3, 1, 4)) # t122: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " # t122 = prims.transpose(t121, (0, 2, 3, 1, 4)) # t122: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " del t121\n", - " (t123, t124, t125) = torch.split(t122, (1, 1, 1), 2)\n", - " # (t123, t124, t125) = ltorch.split(t122, (1, 1, 1), 2)\n", - " # t123 = prims.slice_prim(t122, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t123: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " # t124 = prims.slice_prim(t122, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t124: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " # t125 = prims.slice_prim(t122, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t125: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " del t122\n", - " t126 = torch.reshape(t123, (1, 32, 512, 128)) # t126: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t126 = ltorch.reshape(t123, (1, 32, 512, 128)) # t126: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t126 = prims.reshape(t123, (1, 32, 512, 128)) # t126: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t123\n", - " t127 = torch.reshape(t124, (1, 32, 512, 128)) # t127: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t127 = ltorch.reshape(t124, (1, 32, 512, 128)) # t127: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t127 = prims.reshape(t124, (1, 32, 512, 128)) # t127: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t124\n", - " t128 = torch.reshape(t125, (1, 32, 512, 128)) # t128: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t128 = ltorch.reshape(t125, (1, 32, 512, 128)) # t128: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t128 = prims.reshape(t125, (1, 32, 512, 128)) # t128: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t125\n", - " t149 = torch_slice_prim_impl(t126, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t149: \"cuda:0 f32[1, 32, 512, 0]\"\n", - " t151 = torch_slice_prim_impl(t127, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t151: \"cuda:0 f32[1, 32, 512, 0]\"\n", - " t129 = torch_slice_prim_impl(t126, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t129: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t126\n", - " t139 = torch_slice_prim_impl(t127, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t139: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t127\n", - " t130 = torch_slice_prim_impl(t129, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t130: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t131 = torch_slice_prim_impl(t129, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t131: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t141 = 
torch_slice_prim_impl(t139, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t141: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t140 = torch_slice_prim_impl(t139, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t140: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " [t132, t142] = nvFusion6(t131, t141)\n", - " # t132 = prims.neg(t131) # t132: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " # t142 = prims.neg(t141) # t142: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " del t131, t141\n", - " t143 = torch.cat((t142, t140), -1) # t143: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t143 = ltorch.cat((t142, t140), -1) # t143: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t143 = prims.cat((t142, t140), -1) # t143: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t142, t140\n", - " t133 = torch.cat((t132, t130), -1) # t133: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t133 = ltorch.cat((t132, t130), -1) # t133: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t133 = prims.cat((t132, t130), -1) # t133: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t132, t130\n", - " [t138, t148] = nvFusion7(t129, t133, t139, t143, t63, t65)\n", - " # t145 = prims.mul(t139, t63) # t145: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t147 = prims.mul(t143, t65) # t147: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t148 = prims.add(t145, t147) # t148: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t135 = prims.mul(t129, t63) # t135: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t137 = prims.mul(t133, t65) # t137: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t138 = prims.add(t135, t137) # t138: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t129, t133, t139, t143\n", - " t150 = torch.cat((t138, t149), -1) # t150: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t150 = ltorch.cat((t138, t149), -1) # t150: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t150 = prims.cat((t138, t149), -1) # t150: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t138, t149\n", - " t152 = torch.cat((t148, t151), -1) # t152: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t152 = ltorch.cat((t148, t151), -1) # t152: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t152 = prims.cat((t148, t151), -1) # t152: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t148, t151\n", - " (t153, t154, t155, t156) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t150, t152, t128, None, 0.0, True, 0.08838834764831843)\n", - " t157 = torch.permute(t153, (0, 2, 1, 3)) # t157: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " # t157 = ltorch.permute(t153, (0, 2, 1, 3)) # t157: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " # t157 = prims.transpose(t153, (0, 2, 1, 3)) # t157: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " t158 = torch.reshape(t157, (1, 512, 4096)) # t158: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t158 = ltorch.reshape(t157, (1, 512, 4096)) # t158: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t158 = prims.reshape(t157, (1, 512, 4096)) # t158: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t157\n", - " t159 = torch.nn.functional.linear(t158, t27, None) # t159: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t159 = ltorch.linear(t158, t27, None) # t159: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t159 = prims.linear(t158, t27, None) # t159: \"cuda:0 f32[1, 512, 4096]\"\n", - " [t160, t166, t170] = nvFusion8(t109, t159, t169)\n", - " # t160 = prims.add(t159, t109) # t160: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t161 = prims.mul(t160, t160) # t161: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t162 = prims.sum(t161, (2,)) # t162: \"cuda:0 f32[1, 512]\"\n", - " # t163 = prims.broadcast_in_dim(t162, [1, 512, 1], [0, 1]) # t163: \"cuda:0 f32[1, 512, 1]\"\n", - " # t164 = prims.div(t163, 4096.0) # t164: \"cuda:0 f32[1, 
512, 1]\"\n", - " # t165 = prims.add(t164, 1e-05) # t165: \"cuda:0 f32[1, 512, 1]\"\n", - " # t166 = prims.rsqrt(t165) # t166: \"cuda:0 f32[1, 512, 1]\"\n", - " # t167 = prims.broadcast_in_dim(t166, (1, 512, 4096), (0, 1, 2)) # t167: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t168 = prims.mul(t160, t167) # t168: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t170 = prims.mul(t168, t169) # t170: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t159\n", - " t172 = torch.nn.functional.linear(t170, t12, None) # t172: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t172 = ltorch.linear(t170, t12, None) # t172: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t172 = prims.linear(t170, t12, None) # t172: \"cuda:0 f32[1, 512, 11008]\"\n", - " t171 = torch.nn.functional.linear(t170, t8, None) # t171: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t171 = ltorch.linear(t170, t8, None) # t171: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t171 = prims.linear(t170, t8, None) # t171: \"cuda:0 f32[1, 512, 11008]\"\n", - " [t178] = nvFusion9(t171, t172)\n", - " # t173 = prims.neg(t171) # t173: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t174 = prims.exp(t173) # t174: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t175 = prims.add(1.0, t174) # t175: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t176 = prims.reciprocal(t175) # t176: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t177 = prims.mul(t171, t176) # t177: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t178 = prims.mul(t177, t172) # t178: \"cuda:0 f32[1, 512, 11008]\"\n", - " t179 = torch.nn.functional.linear(t178, t28, None) # t179: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t179 = ltorch.linear(t178, t28, None) # t179: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t179 = prims.linear(t178, t28, None) # t179: \"cuda:0 f32[1, 512, 4096]\"\n", - " [t180, t186, t190] = nvFusion10(t160, t179, t189)\n", - " # t180 = prims.add(t179, t160) # t180: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t181 = prims.mul(t180, t180) # t181: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t182 = prims.sum(t181, (2,)) # t182: \"cuda:0 f32[1, 512]\"\n", - " # t183 = prims.broadcast_in_dim(t182, [1, 512, 1], [0, 1]) # t183: \"cuda:0 f32[1, 512, 1]\"\n", - " # t184 = prims.div(t183, 4096.0) # t184: \"cuda:0 f32[1, 512, 1]\"\n", - " # t185 = prims.add(t184, 1e-05) # t185: \"cuda:0 f32[1, 512, 1]\"\n", - " # t186 = prims.rsqrt(t185) # t186: \"cuda:0 f32[1, 512, 1]\"\n", - " # t187 = prims.broadcast_in_dim(t186, (1, 512, 4096), (0, 1, 2)) # t187: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t188 = prims.mul(t180, t187) # t188: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t190 = prims.mul(t188, t189) # t190: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t179\n", - " t191 = torch.nn.functional.linear(t190, t5, None) # t191: \"cuda:0 f32[1, 512, 12288]\"\n", - " # t191 = ltorch.linear(t190, t5, None) # t191: \"cuda:0 f32[1, 512, 12288]\"\n", - " # t191 = prims.linear(t190, t5, None) # t191: \"cuda:0 f32[1, 512, 12288]\"\n", - " t192 = torch.reshape(t191, (1, 512, 32, 3, 128)) # t192: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " # t192 = ltorch.reshape(t191, (1, 512, 32, 3, 128)) # t192: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " # t192 = prims.reshape(t191, (1, 512, 32, 3, 128)) # t192: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " del t191\n", - " t193 = torch.permute(t192, (0, 2, 3, 1, 4)) # t193: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " # t193 = ltorch.permute(t192, (0, 2, 3, 1, 4)) # t193: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " # t193 = prims.transpose(t192, (0, 2, 3, 1, 4)) # t193: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " del t192\n", - " (t194, t195, t196) = torch.split(t193, (1, 1, 1), 
2)\n", - " # (t194, t195, t196) = ltorch.split(t193, (1, 1, 1), 2)\n", - " # t194 = prims.slice_prim(t193, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t194: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " # t195 = prims.slice_prim(t193, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t195: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " # t196 = prims.slice_prim(t193, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t196: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " del t193\n", - " t197 = torch.reshape(t194, (1, 32, 512, 128)) # t197: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t197 = ltorch.reshape(t194, (1, 32, 512, 128)) # t197: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t197 = prims.reshape(t194, (1, 32, 512, 128)) # t197: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t194\n", - " t198 = torch.reshape(t195, (1, 32, 512, 128)) # t198: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t198 = ltorch.reshape(t195, (1, 32, 512, 128)) # t198: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t198 = prims.reshape(t195, (1, 32, 512, 128)) # t198: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t195\n", - " t199 = torch.reshape(t196, (1, 32, 512, 128)) # t199: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t199 = ltorch.reshape(t196, (1, 32, 512, 128)) # t199: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t199 = prims.reshape(t196, (1, 32, 512, 128)) # t199: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t196\n", - " t200 = torch_slice_prim_impl(t197, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t200: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " t210 = torch_slice_prim_impl(t198, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t210: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " t220 = torch_slice_prim_impl(t197, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t220: \"cuda:0 f32[1, 32, 512, 0]\"\n", - " del t197\n", - " t222 = torch_slice_prim_impl(t198, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t222: \"cuda:0 f32[1, 32, 512, 0]\"\n", - " del t198\n", - " t201 = torch_slice_prim_impl(t200, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t201: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t202 = torch_slice_prim_impl(t200, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t202: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t211 = torch_slice_prim_impl(t210, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t211: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t212 = torch_slice_prim_impl(t210, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t212: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " [t203, t213] = nvFusion11(t202, t212)\n", - " # t203 = prims.neg(t202) # t203: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " # t213 = prims.neg(t212) # t213: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " del t202, t212\n", - " t214 = torch.cat((t213, t211), -1) # t214: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t214 = ltorch.cat((t213, t211), -1) # t214: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t214 = prims.cat((t213, t211), -1) # t214: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t213, t211\n", - " t204 = torch.cat((t203, t201), -1) # t204: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t204 = ltorch.cat((t203, t201), -1) # t204: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t204 = prims.cat((t203, t201), -1) # t204: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t203, t201\n", - " [t209, t219] = nvFusion12(t200, t204, t210, t214, t63, t65)\n", - " # t216 = prims.mul(t210, t63) # t216: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t218 = prims.mul(t214, t65) # t218: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t219 = prims.add(t216, t218) # t219: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # 
t206 = prims.mul(t200, t63) # t206: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t208 = prims.mul(t204, t65) # t208: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t209 = prims.add(t206, t208) # t209: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t200, t204, t210, t214\n", - " t223 = torch.cat((t219, t222), -1) # t223: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t223 = ltorch.cat((t219, t222), -1) # t223: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t223 = prims.cat((t219, t222), -1) # t223: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t219, t222\n", - " t221 = torch.cat((t209, t220), -1) # t221: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t221 = ltorch.cat((t209, t220), -1) # t221: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t221 = prims.cat((t209, t220), -1) # t221: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t209, t220\n", - " (t224, t225, t226, t227) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t221, t223, t199, None, 0.0, True, 0.08838834764831843)\n", - " t228 = torch.permute(t224, (0, 2, 1, 3)) # t228: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " # t228 = ltorch.permute(t224, (0, 2, 1, 3)) # t228: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " # t228 = prims.transpose(t224, (0, 2, 1, 3)) # t228: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " t229 = torch.reshape(t228, (1, 512, 4096)) # t229: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t229 = ltorch.reshape(t228, (1, 512, 4096)) # t229: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t229 = prims.reshape(t228, (1, 512, 4096)) # t229: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t228\n", - " t230 = torch.nn.functional.linear(t229, t29, None) # t230: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t230 = ltorch.linear(t229, t29, None) # t230: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t230 = prims.linear(t229, t29, None) # t230: \"cuda:0 f32[1, 512, 4096]\"\n", - " [t231, t237, t241] = nvFusion13(t180, t230, t240)\n", - " # t231 = prims.add(t230, t180) # t231: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t232 = prims.mul(t231, t231) # t232: \"cuda:0 f32[1, 512, 4096]\"\n", + " t122 = torch.nn.functional.embedding(t0, t117, None, None, 2.0, False, False) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t122 = ltorch.embedding(t0, t117, None, None, 2.0, False, False) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1867 = ltorch.reshape(t0, [512]) # t1867: \"cuda:0 i64[512]\"\n", + " # t1867 = prims.reshape(t0, (512,)) # t1867: \"cuda:0 i64[512]\"\n", + " # t1868 = prims.take(t117, t1867, 0) # t1868: \"cuda:0 bf16[512, 4096]\"\n", + " # t122 = ltorch.reshape(t1868, [1, 512, 4096]) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t122 = prims.reshape(t1868, (1, 512, 4096)) # t122: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t118 = torch_slice_prim_impl(t1, [0, 0], [512, 128], [1, 1]) # t118: \"cuda:0 f32[512, 128]\"\n", + " t119 = torch_slice_prim_impl(t2, [0, 0], [512, 128], [1, 1]) # t119: \"cuda:0 f32[512, 128]\"\n", + " t2015 = torch.unsqueeze(t53, 0) # t2015: \"cuda:0 bf16[1, 4096]\"\n", + " # t2015 = ltorch.unsqueeze(t53, 0) # t2015: \"cuda:0 bf16[1, 4096]\"\n", + " # t2015 = prims.broadcast_in_dim(t53, [1, 4096], [1]) # t2015: \"cuda:0 bf16[1, 4096]\"\n", + " t2016 = torch.unsqueeze(t2015, 1) # t2016: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2016 = ltorch.unsqueeze(t2015, 1) # t2016: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2016 = prims.broadcast_in_dim(t2015, [1, 1, 4096], [0, 2]) # t2016: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2015\n", + " t133 = Tensor.expand(t2016, (1, 512, 4096)) # t133: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t133 = ltorch.expand(t2016, (1, 512, 4096)) # t133: \"cuda:0 bf16[1, 
512, 4096]\"\n", + " # t133 = prims.broadcast_in_dim(t2016, (1, 512, 4096), (0, 1, 2)) # t133: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2016\n", + " t2356 = torch.unsqueeze(t82, 0) # t2356: \"cuda:0 bf16[1, 4096]\"\n", + " # t2356 = ltorch.unsqueeze(t82, 0) # t2356: \"cuda:0 bf16[1, 4096]\"\n", + " # t2356 = prims.broadcast_in_dim(t82, [1, 4096], [1]) # t2356: \"cuda:0 bf16[1, 4096]\"\n", + " t2357 = torch.unsqueeze(t2356, 1) # t2357: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2357 = ltorch.unsqueeze(t2356, 1) # t2357: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2357 = prims.broadcast_in_dim(t2356, [1, 1, 4096], [0, 2]) # t2357: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2356\n", + " t1609 = Tensor.expand(t2357, (1, 512, 4096)) # t1609: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1609 = ltorch.expand(t2357, (1, 512, 4096)) # t1609: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1609 = prims.broadcast_in_dim(t2357, (1, 512, 4096), (0, 1, 2)) # t1609: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2357\n", + " t2359 = torch.unsqueeze(t58, 0) # t2359: \"cuda:0 bf16[1, 4096]\"\n", + " # t2359 = ltorch.unsqueeze(t58, 0) # t2359: \"cuda:0 bf16[1, 4096]\"\n", + " # t2359 = prims.broadcast_in_dim(t58, [1, 4096], [1]) # t2359: \"cuda:0 bf16[1, 4096]\"\n", + " t2360 = torch.unsqueeze(t2359, 1) # t2360: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2360 = ltorch.unsqueeze(t2359, 1) # t2360: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2360 = prims.broadcast_in_dim(t2359, [1, 1, 4096], [0, 2]) # t2360: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2359\n", + " t1645 = Tensor.expand(t2360, (1, 512, 4096)) # t1645: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1645 = ltorch.expand(t2360, (1, 512, 4096)) # t1645: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1645 = prims.broadcast_in_dim(t2360, (1, 512, 4096), (0, 1, 2)) # t1645: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2360\n", + " t2044 = torch.unsqueeze(t69, 0) # t2044: \"cuda:0 bf16[1, 4096]\"\n", + " # t2044 = ltorch.unsqueeze(t69, 0) # t2044: \"cuda:0 bf16[1, 4096]\"\n", + " # t2044 = prims.broadcast_in_dim(t69, [1, 4096], [1]) # t2044: \"cuda:0 bf16[1, 4096]\"\n", + " t2045 = torch.unsqueeze(t2044, 1) # t2045: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2045 = ltorch.unsqueeze(t2044, 1) # t2045: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2045 = prims.broadcast_in_dim(t2044, [1, 1, 4096], [0, 2]) # t2045: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2044\n", + " t205 = Tensor.expand(t2045, (1, 512, 4096)) # t205: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t205 = ltorch.expand(t2045, (1, 512, 4096)) # t205: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t205 = prims.broadcast_in_dim(t2045, (1, 512, 4096), (0, 1, 2)) # t205: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2045\n", + " t2380 = torch.unsqueeze(t83, 0) # t2380: \"cuda:0 bf16[1, 4096]\"\n", + " # t2380 = ltorch.unsqueeze(t83, 0) # t2380: \"cuda:0 bf16[1, 4096]\"\n", + " # t2380 = prims.broadcast_in_dim(t83, [1, 4096], [1]) # t2380: \"cuda:0 bf16[1, 4096]\"\n", + " t2381 = torch.unsqueeze(t2380, 1) # t2381: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2381 = ltorch.unsqueeze(t2380, 1) # t2381: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2381 = prims.broadcast_in_dim(t2380, [1, 1, 4096], [0, 2]) # t2381: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2380\n", + " t1717 = Tensor.expand(t2381, (1, 512, 4096)) # t1717: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1717 = ltorch.expand(t2381, (1, 512, 4096)) # t1717: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1717 = prims.broadcast_in_dim(t2381, (1, 512, 4096), (0, 1, 2)) # t1717: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2381\n", + " t2047 
= torch.unsqueeze(t60, 0) # t2047: \"cuda:0 bf16[1, 4096]\"\n", + " # t2047 = ltorch.unsqueeze(t60, 0) # t2047: \"cuda:0 bf16[1, 4096]\"\n", + " # t2047 = prims.broadcast_in_dim(t60, [1, 4096], [1]) # t2047: \"cuda:0 bf16[1, 4096]\"\n", + " t2048 = torch.unsqueeze(t2047, 1) # t2048: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2048 = ltorch.unsqueeze(t2047, 1) # t2048: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2048 = prims.broadcast_in_dim(t2047, [1, 1, 4096], [0, 2]) # t2048: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2047\n", + " t241 = Tensor.expand(t2048, (1, 512, 4096)) # t241: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t241 = ltorch.expand(t2048, (1, 512, 4096)) # t241: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t241 = prims.broadcast_in_dim(t2048, (1, 512, 4096), (0, 1, 2)) # t241: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2048\n", + " t2383 = torch.unsqueeze(t59, 0) # t2383: \"cuda:0 bf16[1, 4096]\"\n", + " # t2383 = ltorch.unsqueeze(t59, 0) # t2383: \"cuda:0 bf16[1, 4096]\"\n", + " # t2383 = prims.broadcast_in_dim(t59, [1, 4096], [1]) # t2383: \"cuda:0 bf16[1, 4096]\"\n", + " t2384 = torch.unsqueeze(t2383, 1) # t2384: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2384 = ltorch.unsqueeze(t2383, 1) # t2384: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2384 = prims.broadcast_in_dim(t2383, [1, 1, 4096], [0, 2]) # t2384: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2383\n", + " t1753 = Tensor.expand(t2384, (1, 512, 4096)) # t1753: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1753 = ltorch.expand(t2384, (1, 512, 4096)) # t1753: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1753 = prims.broadcast_in_dim(t2384, (1, 512, 4096), (0, 1, 2)) # t1753: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2384\n", + " t2068 = torch.unsqueeze(t70, 0) # t2068: \"cuda:0 bf16[1, 4096]\"\n", + " # t2068 = ltorch.unsqueeze(t70, 0) # t2068: \"cuda:0 bf16[1, 4096]\"\n", + " # t2068 = prims.broadcast_in_dim(t70, [1, 4096], [1]) # t2068: \"cuda:0 bf16[1, 4096]\"\n", + " t2069 = torch.unsqueeze(t2068, 1) # t2069: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2069 = ltorch.unsqueeze(t2068, 1) # t2069: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2069 = prims.broadcast_in_dim(t2068, [1, 1, 4096], [0, 2]) # t2069: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2068\n", + " t313 = Tensor.expand(t2069, (1, 512, 4096)) # t313: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t313 = ltorch.expand(t2069, (1, 512, 4096)) # t313: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t313 = prims.broadcast_in_dim(t2069, (1, 512, 4096), (0, 1, 2)) # t313: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2069\n", + " t2404 = torch.unsqueeze(t84, 0) # t2404: \"cuda:0 bf16[1, 4096]\"\n", + " # t2404 = ltorch.unsqueeze(t84, 0) # t2404: \"cuda:0 bf16[1, 4096]\"\n", + " # t2404 = prims.broadcast_in_dim(t84, [1, 4096], [1]) # t2404: \"cuda:0 bf16[1, 4096]\"\n", + " t2405 = torch.unsqueeze(t2404, 1) # t2405: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2405 = ltorch.unsqueeze(t2404, 1) # t2405: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2405 = prims.broadcast_in_dim(t2404, [1, 1, 4096], [0, 2]) # t2405: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2404\n", + " t1825 = Tensor.expand(t2405, (1, 512, 4096)) # t1825: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1825 = ltorch.expand(t2405, (1, 512, 4096)) # t1825: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1825 = prims.broadcast_in_dim(t2405, (1, 512, 4096), (0, 1, 2)) # t1825: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2405\n", + " t2071 = torch.unsqueeze(t61, 0) # t2071: \"cuda:0 bf16[1, 4096]\"\n", + " # t2071 = ltorch.unsqueeze(t61, 0) # t2071: \"cuda:0 bf16[1, 4096]\"\n", + " # t2071 = 
prims.broadcast_in_dim(t61, [1, 4096], [1]) # t2071: \"cuda:0 bf16[1, 4096]\"\n", + " t2072 = torch.unsqueeze(t2071, 1) # t2072: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2072 = ltorch.unsqueeze(t2071, 1) # t2072: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2072 = prims.broadcast_in_dim(t2071, [1, 1, 4096], [0, 2]) # t2072: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2071\n", + " t349 = Tensor.expand(t2072, (1, 512, 4096)) # t349: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t349 = ltorch.expand(t2072, (1, 512, 4096)) # t349: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t349 = prims.broadcast_in_dim(t2072, (1, 512, 4096), (0, 1, 2)) # t349: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2072\n", + " t2407 = torch.unsqueeze(t52, 0) # t2407: \"cuda:0 bf16[1, 4096]\"\n", + " # t2407 = ltorch.unsqueeze(t52, 0) # t2407: \"cuda:0 bf16[1, 4096]\"\n", + " # t2407 = prims.broadcast_in_dim(t52, [1, 4096], [1]) # t2407: \"cuda:0 bf16[1, 4096]\"\n", + " t2408 = torch.unsqueeze(t2407, 1) # t2408: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2408 = ltorch.unsqueeze(t2407, 1) # t2408: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2408 = prims.broadcast_in_dim(t2407, [1, 1, 4096], [0, 2]) # t2408: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2407\n", + " t1861 = Tensor.expand(t2408, (1, 512, 4096)) # t1861: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1861 = ltorch.expand(t2408, (1, 512, 4096)) # t1861: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1861 = prims.broadcast_in_dim(t2408, (1, 512, 4096), (0, 1, 2)) # t1861: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2408\n", + " t2095 = torch.unsqueeze(t62, 0) # t2095: \"cuda:0 bf16[1, 4096]\"\n", + " # t2095 = ltorch.unsqueeze(t62, 0) # t2095: \"cuda:0 bf16[1, 4096]\"\n", + " # t2095 = prims.broadcast_in_dim(t62, [1, 4096], [1]) # t2095: \"cuda:0 bf16[1, 4096]\"\n", + " t2096 = torch.unsqueeze(t2095, 1) # t2096: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2096 = ltorch.unsqueeze(t2095, 1) # t2096: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2096 = prims.broadcast_in_dim(t2095, [1, 1, 4096], [0, 2]) # t2096: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2095\n", + " t457 = Tensor.expand(t2096, (1, 512, 4096)) # t457: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t457 = ltorch.expand(t2096, (1, 512, 4096)) # t457: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t457 = prims.broadcast_in_dim(t2096, (1, 512, 4096), (0, 1, 2)) # t457: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2096\n", + " t2092 = torch.unsqueeze(t71, 0) # t2092: \"cuda:0 bf16[1, 4096]\"\n", + " # t2092 = ltorch.unsqueeze(t71, 0) # t2092: \"cuda:0 bf16[1, 4096]\"\n", + " # t2092 = prims.broadcast_in_dim(t71, [1, 4096], [1]) # t2092: \"cuda:0 bf16[1, 4096]\"\n", + " t2093 = torch.unsqueeze(t2092, 1) # t2093: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2093 = ltorch.unsqueeze(t2092, 1) # t2093: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2093 = prims.broadcast_in_dim(t2092, [1, 1, 4096], [0, 2]) # t2093: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2092\n", + " t421 = Tensor.expand(t2093, (1, 512, 4096)) # t421: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t421 = ltorch.expand(t2093, (1, 512, 4096)) # t421: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t421 = prims.broadcast_in_dim(t2093, (1, 512, 4096), (0, 1, 2)) # t421: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2093\n", + " t2116 = torch.unsqueeze(t72, 0) # t2116: \"cuda:0 bf16[1, 4096]\"\n", + " # t2116 = ltorch.unsqueeze(t72, 0) # t2116: \"cuda:0 bf16[1, 4096]\"\n", + " # t2116 = prims.broadcast_in_dim(t72, [1, 4096], [1]) # t2116: \"cuda:0 bf16[1, 4096]\"\n", + " t2117 = torch.unsqueeze(t2116, 1) # t2117: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # 
t2117 = ltorch.unsqueeze(t2116, 1) # t2117: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2117 = prims.broadcast_in_dim(t2116, [1, 1, 4096], [0, 2]) # t2117: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2116\n", + " t529 = Tensor.expand(t2117, (1, 512, 4096)) # t529: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t529 = ltorch.expand(t2117, (1, 512, 4096)) # t529: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t529 = prims.broadcast_in_dim(t2117, (1, 512, 4096), (0, 1, 2)) # t529: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2117\n", + " t2119 = torch.unsqueeze(t63, 0) # t2119: \"cuda:0 bf16[1, 4096]\"\n", + " # t2119 = ltorch.unsqueeze(t63, 0) # t2119: \"cuda:0 bf16[1, 4096]\"\n", + " # t2119 = prims.broadcast_in_dim(t63, [1, 4096], [1]) # t2119: \"cuda:0 bf16[1, 4096]\"\n", + " t2120 = torch.unsqueeze(t2119, 1) # t2120: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2120 = ltorch.unsqueeze(t2119, 1) # t2120: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2120 = prims.broadcast_in_dim(t2119, [1, 1, 4096], [0, 2]) # t2120: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2119\n", + " t565 = Tensor.expand(t2120, (1, 512, 4096)) # t565: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t565 = ltorch.expand(t2120, (1, 512, 4096)) # t565: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t565 = prims.broadcast_in_dim(t2120, (1, 512, 4096), (0, 1, 2)) # t565: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2120\n", + " t2140 = torch.unsqueeze(t73, 0) # t2140: \"cuda:0 bf16[1, 4096]\"\n", + " # t2140 = ltorch.unsqueeze(t73, 0) # t2140: \"cuda:0 bf16[1, 4096]\"\n", + " # t2140 = prims.broadcast_in_dim(t73, [1, 4096], [1]) # t2140: \"cuda:0 bf16[1, 4096]\"\n", + " t2141 = torch.unsqueeze(t2140, 1) # t2141: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2141 = ltorch.unsqueeze(t2140, 1) # t2141: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2141 = prims.broadcast_in_dim(t2140, [1, 1, 4096], [0, 2]) # t2141: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2140\n", + " t637 = Tensor.expand(t2141, (1, 512, 4096)) # t637: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t637 = ltorch.expand(t2141, (1, 512, 4096)) # t637: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t637 = prims.broadcast_in_dim(t2141, (1, 512, 4096), (0, 1, 2)) # t637: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2141\n", + " t2143 = torch.unsqueeze(t64, 0) # t2143: \"cuda:0 bf16[1, 4096]\"\n", + " # t2143 = ltorch.unsqueeze(t64, 0) # t2143: \"cuda:0 bf16[1, 4096]\"\n", + " # t2143 = prims.broadcast_in_dim(t64, [1, 4096], [1]) # t2143: \"cuda:0 bf16[1, 4096]\"\n", + " t2144 = torch.unsqueeze(t2143, 1) # t2144: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2144 = ltorch.unsqueeze(t2143, 1) # t2144: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2144 = prims.broadcast_in_dim(t2143, [1, 1, 4096], [0, 2]) # t2144: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2143\n", + " t673 = Tensor.expand(t2144, (1, 512, 4096)) # t673: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t673 = ltorch.expand(t2144, (1, 512, 4096)) # t673: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t673 = prims.broadcast_in_dim(t2144, (1, 512, 4096), (0, 1, 2)) # t673: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2144\n", + " t2164 = torch.unsqueeze(t74, 0) # t2164: \"cuda:0 bf16[1, 4096]\"\n", + " # t2164 = ltorch.unsqueeze(t74, 0) # t2164: \"cuda:0 bf16[1, 4096]\"\n", + " # t2164 = prims.broadcast_in_dim(t74, [1, 4096], [1]) # t2164: \"cuda:0 bf16[1, 4096]\"\n", + " t2165 = torch.unsqueeze(t2164, 1) # t2165: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2165 = ltorch.unsqueeze(t2164, 1) # t2165: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2165 = prims.broadcast_in_dim(t2164, [1, 1, 4096], [0, 2]) # t2165: \"cuda:0 bf16[1, 1, 
4096]\"\n", + " del t2164\n", + " t745 = Tensor.expand(t2165, (1, 512, 4096)) # t745: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t745 = ltorch.expand(t2165, (1, 512, 4096)) # t745: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t745 = prims.broadcast_in_dim(t2165, (1, 512, 4096), (0, 1, 2)) # t745: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2165\n", + " t2167 = torch.unsqueeze(t65, 0) # t2167: \"cuda:0 bf16[1, 4096]\"\n", + " # t2167 = ltorch.unsqueeze(t65, 0) # t2167: \"cuda:0 bf16[1, 4096]\"\n", + " # t2167 = prims.broadcast_in_dim(t65, [1, 4096], [1]) # t2167: \"cuda:0 bf16[1, 4096]\"\n", + " t2168 = torch.unsqueeze(t2167, 1) # t2168: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2168 = ltorch.unsqueeze(t2167, 1) # t2168: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2168 = prims.broadcast_in_dim(t2167, [1, 1, 4096], [0, 2]) # t2168: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2167\n", + " t781 = Tensor.expand(t2168, (1, 512, 4096)) # t781: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t781 = ltorch.expand(t2168, (1, 512, 4096)) # t781: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t781 = prims.broadcast_in_dim(t2168, (1, 512, 4096), (0, 1, 2)) # t781: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2168\n", + " t2188 = torch.unsqueeze(t75, 0) # t2188: \"cuda:0 bf16[1, 4096]\"\n", + " # t2188 = ltorch.unsqueeze(t75, 0) # t2188: \"cuda:0 bf16[1, 4096]\"\n", + " # t2188 = prims.broadcast_in_dim(t75, [1, 4096], [1]) # t2188: \"cuda:0 bf16[1, 4096]\"\n", + " t2189 = torch.unsqueeze(t2188, 1) # t2189: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2189 = ltorch.unsqueeze(t2188, 1) # t2189: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2189 = prims.broadcast_in_dim(t2188, [1, 1, 4096], [0, 2]) # t2189: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2188\n", + " t853 = Tensor.expand(t2189, (1, 512, 4096)) # t853: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t853 = ltorch.expand(t2189, (1, 512, 4096)) # t853: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t853 = prims.broadcast_in_dim(t2189, (1, 512, 4096), (0, 1, 2)) # t853: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2189\n", + " t2191 = torch.unsqueeze(t66, 0) # t2191: \"cuda:0 bf16[1, 4096]\"\n", + " # t2191 = ltorch.unsqueeze(t66, 0) # t2191: \"cuda:0 bf16[1, 4096]\"\n", + " # t2191 = prims.broadcast_in_dim(t66, [1, 4096], [1]) # t2191: \"cuda:0 bf16[1, 4096]\"\n", + " t2192 = torch.unsqueeze(t2191, 1) # t2192: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2192 = ltorch.unsqueeze(t2191, 1) # t2192: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2192 = prims.broadcast_in_dim(t2191, [1, 1, 4096], [0, 2]) # t2192: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2191\n", + " t889 = Tensor.expand(t2192, (1, 512, 4096)) # t889: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t889 = ltorch.expand(t2192, (1, 512, 4096)) # t889: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t889 = prims.broadcast_in_dim(t2192, (1, 512, 4096), (0, 1, 2)) # t889: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2192\n", + " t2212 = torch.unsqueeze(t76, 0) # t2212: \"cuda:0 bf16[1, 4096]\"\n", + " # t2212 = ltorch.unsqueeze(t76, 0) # t2212: \"cuda:0 bf16[1, 4096]\"\n", + " # t2212 = prims.broadcast_in_dim(t76, [1, 4096], [1]) # t2212: \"cuda:0 bf16[1, 4096]\"\n", + " t2213 = torch.unsqueeze(t2212, 1) # t2213: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2213 = ltorch.unsqueeze(t2212, 1) # t2213: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2213 = prims.broadcast_in_dim(t2212, [1, 1, 4096], [0, 2]) # t2213: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2212\n", + " t961 = Tensor.expand(t2213, (1, 512, 4096)) # t961: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t961 = ltorch.expand(t2213, (1, 512, 4096)) # 
t961: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t961 = prims.broadcast_in_dim(t2213, (1, 512, 4096), (0, 1, 2)) # t961: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2213\n", + " t2215 = torch.unsqueeze(t67, 0) # t2215: \"cuda:0 bf16[1, 4096]\"\n", + " # t2215 = ltorch.unsqueeze(t67, 0) # t2215: \"cuda:0 bf16[1, 4096]\"\n", + " # t2215 = prims.broadcast_in_dim(t67, [1, 4096], [1]) # t2215: \"cuda:0 bf16[1, 4096]\"\n", + " t2216 = torch.unsqueeze(t2215, 1) # t2216: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2216 = ltorch.unsqueeze(t2215, 1) # t2216: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2216 = prims.broadcast_in_dim(t2215, [1, 1, 4096], [0, 2]) # t2216: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2215\n", + " t997 = Tensor.expand(t2216, (1, 512, 4096)) # t997: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t997 = ltorch.expand(t2216, (1, 512, 4096)) # t997: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t997 = prims.broadcast_in_dim(t2216, (1, 512, 4096), (0, 1, 2)) # t997: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2216\n", + " t2236 = torch.unsqueeze(t77, 0) # t2236: \"cuda:0 bf16[1, 4096]\"\n", + " # t2236 = ltorch.unsqueeze(t77, 0) # t2236: \"cuda:0 bf16[1, 4096]\"\n", + " # t2236 = prims.broadcast_in_dim(t77, [1, 4096], [1]) # t2236: \"cuda:0 bf16[1, 4096]\"\n", + " t2237 = torch.unsqueeze(t2236, 1) # t2237: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2237 = ltorch.unsqueeze(t2236, 1) # t2237: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2237 = prims.broadcast_in_dim(t2236, [1, 1, 4096], [0, 2]) # t2237: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2236\n", + " t1069 = Tensor.expand(t2237, (1, 512, 4096)) # t1069: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1069 = ltorch.expand(t2237, (1, 512, 4096)) # t1069: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1069 = prims.broadcast_in_dim(t2237, (1, 512, 4096), (0, 1, 2)) # t1069: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2237\n", + " t2239 = torch.unsqueeze(t68, 0) # t2239: \"cuda:0 bf16[1, 4096]\"\n", + " # t2239 = ltorch.unsqueeze(t68, 0) # t2239: \"cuda:0 bf16[1, 4096]\"\n", + " # t2239 = prims.broadcast_in_dim(t68, [1, 4096], [1]) # t2239: \"cuda:0 bf16[1, 4096]\"\n", + " t2240 = torch.unsqueeze(t2239, 1) # t2240: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2240 = ltorch.unsqueeze(t2239, 1) # t2240: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2240 = prims.broadcast_in_dim(t2239, [1, 1, 4096], [0, 2]) # t2240: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2239\n", + " t1105 = Tensor.expand(t2240, (1, 512, 4096)) # t1105: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1105 = ltorch.expand(t2240, (1, 512, 4096)) # t1105: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1105 = prims.broadcast_in_dim(t2240, (1, 512, 4096), (0, 1, 2)) # t1105: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2240\n", + " t2260 = torch.unsqueeze(t78, 0) # t2260: \"cuda:0 bf16[1, 4096]\"\n", + " # t2260 = ltorch.unsqueeze(t78, 0) # t2260: \"cuda:0 bf16[1, 4096]\"\n", + " # t2260 = prims.broadcast_in_dim(t78, [1, 4096], [1]) # t2260: \"cuda:0 bf16[1, 4096]\"\n", + " t2261 = torch.unsqueeze(t2260, 1) # t2261: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2261 = ltorch.unsqueeze(t2260, 1) # t2261: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2261 = prims.broadcast_in_dim(t2260, [1, 1, 4096], [0, 2]) # t2261: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2260\n", + " t1177 = Tensor.expand(t2261, (1, 512, 4096)) # t1177: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1177 = ltorch.expand(t2261, (1, 512, 4096)) # t1177: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1177 = prims.broadcast_in_dim(t2261, (1, 512, 4096), (0, 1, 2)) # t1177: \"cuda:0 bf16[1, 512, 4096]\"\n", + " 
del t2261\n", + " t2263 = torch.unsqueeze(t54, 0) # t2263: \"cuda:0 bf16[1, 4096]\"\n", + " # t2263 = ltorch.unsqueeze(t54, 0) # t2263: \"cuda:0 bf16[1, 4096]\"\n", + " # t2263 = prims.broadcast_in_dim(t54, [1, 4096], [1]) # t2263: \"cuda:0 bf16[1, 4096]\"\n", + " t2264 = torch.unsqueeze(t2263, 1) # t2264: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2264 = ltorch.unsqueeze(t2263, 1) # t2264: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2264 = prims.broadcast_in_dim(t2263, [1, 1, 4096], [0, 2]) # t2264: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2263\n", + " t1213 = Tensor.expand(t2264, (1, 512, 4096)) # t1213: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1213 = ltorch.expand(t2264, (1, 512, 4096)) # t1213: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1213 = prims.broadcast_in_dim(t2264, (1, 512, 4096), (0, 1, 2)) # t1213: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2264\n", + " t2284 = torch.unsqueeze(t79, 0) # t2284: \"cuda:0 bf16[1, 4096]\"\n", + " # t2284 = ltorch.unsqueeze(t79, 0) # t2284: \"cuda:0 bf16[1, 4096]\"\n", + " # t2284 = prims.broadcast_in_dim(t79, [1, 4096], [1]) # t2284: \"cuda:0 bf16[1, 4096]\"\n", + " t2285 = torch.unsqueeze(t2284, 1) # t2285: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2285 = ltorch.unsqueeze(t2284, 1) # t2285: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2285 = prims.broadcast_in_dim(t2284, [1, 1, 4096], [0, 2]) # t2285: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2284\n", + " t1285 = Tensor.expand(t2285, (1, 512, 4096)) # t1285: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1285 = ltorch.expand(t2285, (1, 512, 4096)) # t1285: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1285 = prims.broadcast_in_dim(t2285, (1, 512, 4096), (0, 1, 2)) # t1285: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2285\n", + " t2287 = torch.unsqueeze(t55, 0) # t2287: \"cuda:0 bf16[1, 4096]\"\n", + " # t2287 = ltorch.unsqueeze(t55, 0) # t2287: \"cuda:0 bf16[1, 4096]\"\n", + " # t2287 = prims.broadcast_in_dim(t55, [1, 4096], [1]) # t2287: \"cuda:0 bf16[1, 4096]\"\n", + " t2288 = torch.unsqueeze(t2287, 1) # t2288: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2288 = ltorch.unsqueeze(t2287, 1) # t2288: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2288 = prims.broadcast_in_dim(t2287, [1, 1, 4096], [0, 2]) # t2288: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2287\n", + " t1321 = Tensor.expand(t2288, (1, 512, 4096)) # t1321: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1321 = ltorch.expand(t2288, (1, 512, 4096)) # t1321: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1321 = prims.broadcast_in_dim(t2288, (1, 512, 4096), (0, 1, 2)) # t1321: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2288\n", + " t2308 = torch.unsqueeze(t80, 0) # t2308: \"cuda:0 bf16[1, 4096]\"\n", + " # t2308 = ltorch.unsqueeze(t80, 0) # t2308: \"cuda:0 bf16[1, 4096]\"\n", + " # t2308 = prims.broadcast_in_dim(t80, [1, 4096], [1]) # t2308: \"cuda:0 bf16[1, 4096]\"\n", + " t2309 = torch.unsqueeze(t2308, 1) # t2309: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2309 = ltorch.unsqueeze(t2308, 1) # t2309: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2309 = prims.broadcast_in_dim(t2308, [1, 1, 4096], [0, 2]) # t2309: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2308\n", + " t1393 = Tensor.expand(t2309, (1, 512, 4096)) # t1393: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1393 = ltorch.expand(t2309, (1, 512, 4096)) # t1393: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1393 = prims.broadcast_in_dim(t2309, (1, 512, 4096), (0, 1, 2)) # t1393: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2309\n", + " t2311 = torch.unsqueeze(t56, 0) # t2311: \"cuda:0 bf16[1, 4096]\"\n", + " # t2311 = ltorch.unsqueeze(t56, 0) # t2311: \"cuda:0 
bf16[1, 4096]\"\n", + " # t2311 = prims.broadcast_in_dim(t56, [1, 4096], [1]) # t2311: \"cuda:0 bf16[1, 4096]\"\n", + " t2312 = torch.unsqueeze(t2311, 1) # t2312: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2312 = ltorch.unsqueeze(t2311, 1) # t2312: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2312 = prims.broadcast_in_dim(t2311, [1, 1, 4096], [0, 2]) # t2312: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2311\n", + " t1429 = Tensor.expand(t2312, (1, 512, 4096)) # t1429: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1429 = ltorch.expand(t2312, (1, 512, 4096)) # t1429: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1429 = prims.broadcast_in_dim(t2312, (1, 512, 4096), (0, 1, 2)) # t1429: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2312\n", + " t2332 = torch.unsqueeze(t81, 0) # t2332: \"cuda:0 bf16[1, 4096]\"\n", + " # t2332 = ltorch.unsqueeze(t81, 0) # t2332: \"cuda:0 bf16[1, 4096]\"\n", + " # t2332 = prims.broadcast_in_dim(t81, [1, 4096], [1]) # t2332: \"cuda:0 bf16[1, 4096]\"\n", + " t2333 = torch.unsqueeze(t2332, 1) # t2333: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2333 = ltorch.unsqueeze(t2332, 1) # t2333: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2333 = prims.broadcast_in_dim(t2332, [1, 1, 4096], [0, 2]) # t2333: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2332\n", + " t1501 = Tensor.expand(t2333, (1, 512, 4096)) # t1501: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1501 = ltorch.expand(t2333, (1, 512, 4096)) # t1501: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1501 = prims.broadcast_in_dim(t2333, (1, 512, 4096), (0, 1, 2)) # t1501: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2333\n", + " t2335 = torch.unsqueeze(t57, 0) # t2335: \"cuda:0 bf16[1, 4096]\"\n", + " # t2335 = ltorch.unsqueeze(t57, 0) # t2335: \"cuda:0 bf16[1, 4096]\"\n", + " # t2335 = prims.broadcast_in_dim(t57, [1, 4096], [1]) # t2335: \"cuda:0 bf16[1, 4096]\"\n", + " t2336 = torch.unsqueeze(t2335, 1) # t2336: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2336 = ltorch.unsqueeze(t2335, 1) # t2336: \"cuda:0 bf16[1, 1, 4096]\"\n", + " # t2336 = prims.broadcast_in_dim(t2335, [1, 1, 4096], [0, 2]) # t2336: \"cuda:0 bf16[1, 1, 4096]\"\n", + " del t2335\n", + " t1537 = Tensor.expand(t2336, (1, 512, 4096)) # t1537: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1537 = ltorch.expand(t2336, (1, 512, 4096)) # t1537: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1537 = prims.broadcast_in_dim(t2336, (1, 512, 4096), (0, 1, 2)) # t1537: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t2336\n", + " t2036 = torch.unsqueeze(t118, 0) # t2036: \"cuda:0 f32[1, 512, 128]\"\n", + " # t2036 = ltorch.unsqueeze(t118, 0) # t2036: \"cuda:0 f32[1, 512, 128]\"\n", + " # t2036 = prims.broadcast_in_dim(t118, [1, 512, 128], [1, 2]) # t2036: \"cuda:0 f32[1, 512, 128]\"\n", + " del t118\n", + " t2037 = torch.unsqueeze(t2036, 1) # t2037: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " # t2037 = ltorch.unsqueeze(t2036, 1) # t2037: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " # t2037 = prims.broadcast_in_dim(t2036, [1, 1, 512, 128], [0, 2, 3]) # t2037: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " del t2036\n", + " t154 = Tensor.expand(t2037, (1, 32, 512, 128)) # t154: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t154 = ltorch.expand(t2037, (1, 32, 512, 128)) # t154: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t154 = prims.broadcast_in_dim(t2037, (1, 32, 512, 128), (0, 1, 2, 3)) # t154: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t2037\n", + " t2039 = torch.unsqueeze(t119, 0) # t2039: \"cuda:0 f32[1, 512, 128]\"\n", + " # t2039 = ltorch.unsqueeze(t119, 0) # t2039: \"cuda:0 f32[1, 512, 128]\"\n", + " # t2039 = prims.broadcast_in_dim(t119, [1, 512, 128], 
[1, 2]) # t2039: \"cuda:0 f32[1, 512, 128]\"\n", + " del t119\n", + " t2040 = torch.unsqueeze(t2039, 1) # t2040: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " # t2040 = ltorch.unsqueeze(t2039, 1) # t2040: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " # t2040 = prims.broadcast_in_dim(t2039, [1, 1, 512, 128], [0, 2, 3]) # t2040: \"cuda:0 f32[1, 1, 512, 128]\"\n", + " del t2039\n", + " t157 = Tensor.expand(t2040, (1, 32, 512, 128)) # t157: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t157 = ltorch.expand(t2040, (1, 32, 512, 128)) # t157: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t157 = prims.broadcast_in_dim(t2040, (1, 32, 512, 128), (0, 1, 2, 3)) # t157: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " del t2040\n", + " [t129, t137] = nvFusion0(t122, t133)\n", + " # t123 = prims.convert_element_type(t122, dtypes.float32) # t123: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t124 = prims.mul(t123, t123) # t124: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t125 = prims.sum(t124, (2,)) # t125: \"cuda:0 f32[1, 512]\"\n", + " # t126 = prims.broadcast_in_dim(t125, [1, 512, 1], [0, 1]) # t126: \"cuda:0 f32[1, 512, 1]\"\n", + " # t127 = prims.div(t126, 4096.0) # t127: \"cuda:0 f32[1, 512, 1]\"\n", + " # t128 = prims.add(t127, 1e-05) # t128: \"cuda:0 f32[1, 512, 1]\"\n", + " # t129 = prims.rsqrt(t128) # t129: \"cuda:0 f32[1, 512, 1]\"\n", + " # t130 = prims.broadcast_in_dim(t129, (1, 512, 4096), (0, 1, 2)) # t130: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t131 = prims.mul(t123, t130) # t131: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t135 = prims.convert_element_type(t133, dtypes.float32) # t135: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t136 = prims.mul(t131, t135) # t136: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t137 = prims.convert_element_type(t136, dtypes.bfloat16) # t137: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t138 = torch.nn.functional.linear(t137, t3, None) # t138: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t138 = ltorch.linear(t137, t3, None) # t138: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t138 = prims.linear(t137, t3, None) # t138: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t139 = torch.reshape(t138, (1, 512, 32, 3, 128)) # t139: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t139 = ltorch.reshape(t138, (1, 512, 32, 3, 128)) # t139: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t139 = prims.reshape(t138, (1, 512, 32, 3, 128)) # t139: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t138\n", + " t140 = torch.permute(t139, (0, 2, 3, 1, 4)) # t140: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t140 = ltorch.permute(t139, (0, 2, 3, 1, 4)) # t140: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t140 = prims.transpose(t139, (0, 2, 3, 1, 4)) # t140: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t139\n", + " (t141, t142, t143) = torch.split(t140, (1, 1, 1), 2)\n", + " # (t141, t142, t143) = ltorch.split(t140, (1, 1, 1), 2)\n", + " # t141 = prims.slice_prim(t140, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t141: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t142 = prims.slice_prim(t140, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t142: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t143 = prims.slice_prim(t140, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t143: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t140\n", + " t144 = torch.reshape(t141, (1, 32, 512, 128)) # t144: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t144 = ltorch.reshape(t141, (1, 32, 512, 128)) # t144: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t144 = prims.reshape(t141, (1, 32, 512, 128)) # t144: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t141\n", 
+ " t145 = torch.reshape(t142, (1, 32, 512, 128)) # t145: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t145 = ltorch.reshape(t142, (1, 32, 512, 128)) # t145: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t145 = prims.reshape(t142, (1, 32, 512, 128)) # t145: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t142\n", + " t146 = torch.reshape(t143, (1, 32, 512, 128)) # t146: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t146 = ltorch.reshape(t143, (1, 32, 512, 128)) # t146: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t146 = prims.reshape(t143, (1, 32, 512, 128)) # t146: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t143\n", + " t147 = torch_slice_prim_impl(t144, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t147: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t162 = torch_slice_prim_impl(t145, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t162: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t177 = torch_slice_prim_impl(t144, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t177: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t144\n", + " t179 = torch_slice_prim_impl(t145, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t179: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t145\n", + " t149 = torch_slice_prim_impl(t147, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t149: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t148 = torch_slice_prim_impl(t147, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t148: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t163 = torch_slice_prim_impl(t162, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t163: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t164 = torch_slice_prim_impl(t162, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t164: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t152, t167] = nvFusion1(t147, t149, t162, t164)\n", + " # t150 = prims.convert_element_type(t149, dtypes.float32) # t150: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t151 = prims.neg(t150) # t151: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t152 = prims.convert_element_type(t151, dtypes.bfloat16) # t152: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t165 = prims.convert_element_type(t164, dtypes.float32) # t165: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t166 = prims.neg(t165) # t166: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t167 = prims.convert_element_type(t166, dtypes.bfloat16) # t167: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t149, t164\n", + " t168 = torch.cat((t167, t163), -1) # t168: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t168 = ltorch.cat((t167, t163), -1) # t168: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t168 = prims.cat((t167, t163), -1) # t168: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t167, t163\n", + " t153 = torch.cat((t152, t148), -1) # t153: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t153 = ltorch.cat((t152, t148), -1) # t153: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t153 = prims.cat((t152, t148), -1) # t153: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t152, t148\n", + " [t161, t176] = nvFusion2(t147, t153, t154, t157, t162, t168)\n", + " # t155 = prims.convert_element_type(t147, dtypes.float32) # t155: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t170 = prims.convert_element_type(t162, dtypes.float32) # t170: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t156 = prims.mul(t155, t154) # t156: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t158 = prims.convert_element_type(t153, dtypes.float32) # t158: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t159 = prims.mul(t158, t157) # t159: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t160 = prims.add(t156, t159) # t160: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t161 = 
prims.convert_element_type(t160, dtypes.bfloat16) # t161: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t171 = prims.mul(t170, t154) # t171: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t173 = prims.convert_element_type(t168, dtypes.float32) # t173: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t174 = prims.mul(t173, t157) # t174: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t175 = prims.add(t171, t174) # t175: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t176 = prims.convert_element_type(t175, dtypes.bfloat16) # t176: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t147, t153, t162, t168\n", + " t178 = torch.cat((t161, t177), -1) # t178: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t178 = ltorch.cat((t161, t177), -1) # t178: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t178 = prims.cat((t161, t177), -1) # t178: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t161, t177\n", + " t180 = torch.cat((t176, t179), -1) # t180: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t180 = ltorch.cat((t176, t179), -1) # t180: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t180 = prims.cat((t176, t179), -1) # t180: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t176, t179\n", + " (t181, t182, t183, t184, _, _, t185, t186, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t178, t180, t146, 0.0, True, scale=0.08838834764831843)\n", + " t188 = torch.permute(t181, (0, 2, 1, 3)) # t188: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t188 = ltorch.permute(t181, (0, 2, 1, 3)) # t188: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t188 = prims.transpose(t181, (0, 2, 1, 3)) # t188: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t189 = torch.reshape(t188, (1, 512, 4096)) # t189: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t189 = ltorch.reshape(t188, (1, 512, 4096)) # t189: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t189 = prims.reshape(t188, (1, 512, 4096)) # t189: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t188\n", + " t190 = torch.nn.functional.linear(t189, t85, None) # t190: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t190 = ltorch.linear(t189, t85, None) # t190: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t190 = prims.linear(t189, t85, None) # t190: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t194, t201, t209] = nvFusion3(t122, t190, t205)\n", + " # t191 = prims.convert_element_type(t190, dtypes.float32) # t191: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t192 = prims.convert_element_type(t122, dtypes.float32) # t192: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t193 = prims.add(t191, t192) # t193: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t194 = prims.convert_element_type(t193, dtypes.bfloat16) # t194: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t196 = prims.mul(t193, t193) # t196: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t197 = prims.sum(t196, (2,)) # t197: \"cuda:0 f32[1, 512]\"\n", + " # t198 = prims.broadcast_in_dim(t197, [1, 512, 1], [0, 1]) # t198: \"cuda:0 f32[1, 512, 1]\"\n", + " # t199 = prims.div(t198, 4096.0) # t199: \"cuda:0 f32[1, 512, 1]\"\n", + " # t200 = prims.add(t199, 1e-05) # t200: \"cuda:0 f32[1, 512, 1]\"\n", + " # t201 = prims.rsqrt(t200) # t201: \"cuda:0 f32[1, 512, 1]\"\n", + " # t202 = prims.broadcast_in_dim(t201, (1, 512, 4096), (0, 1, 2)) # t202: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t203 = prims.mul(t193, t202) # t203: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t207 = prims.convert_element_type(t205, dtypes.float32) # t207: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t208 = prims.mul(t203, t207) # t208: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t209 = prims.convert_element_type(t208, dtypes.bfloat16) # t209: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t210 = 
torch.nn.functional.linear(t209, t19, None) # t210: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t210 = ltorch.linear(t209, t19, None) # t210: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t210 = prims.linear(t209, t19, None) # t210: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t211 = torch.nn.functional.linear(t209, t35, None) # t211: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t211 = ltorch.linear(t209, t35, None) # t211: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t211 = prims.linear(t209, t35, None) # t211: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t225] = nvFusion4(t210, t211)\n", + " # t212 = prims.convert_element_type(t210, dtypes.float32) # t212: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t213 = prims.neg(t212) # t213: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t214 = prims.exp(t213) # t214: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t215 = prims.add(1.0, t214) # t215: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t216 = prims.reciprocal(t215) # t216: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t220 = prims.mul(t212, t216) # t220: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t223 = prims.convert_element_type(t211, dtypes.float32) # t223: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t224 = prims.mul(t220, t223) # t224: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t225 = prims.convert_element_type(t224, dtypes.bfloat16) # t225: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t226 = torch.nn.functional.linear(t225, t86, None) # t226: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t226 = ltorch.linear(t225, t86, None) # t226: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t226 = prims.linear(t225, t86, None) # t226: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t230, t237, t245] = nvFusion5(t194, t226, t241)\n", + " # t228 = prims.convert_element_type(t194, dtypes.float32) # t228: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t227 = prims.convert_element_type(t226, dtypes.float32) # t227: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t229 = prims.add(t227, t228) # t229: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t230 = prims.convert_element_type(t229, dtypes.bfloat16) # t230: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t232 = prims.mul(t229, t229) # t232: \"cuda:0 f32[1, 512, 4096]\"\n", " # t233 = prims.sum(t232, (2,)) # t233: \"cuda:0 f32[1, 512]\"\n", " # t234 = prims.broadcast_in_dim(t233, [1, 512, 1], [0, 1]) # t234: \"cuda:0 f32[1, 512, 1]\"\n", " # t235 = prims.div(t234, 4096.0) # t235: \"cuda:0 f32[1, 512, 1]\"\n", " # t236 = prims.add(t235, 1e-05) # t236: \"cuda:0 f32[1, 512, 1]\"\n", " # t237 = prims.rsqrt(t236) # t237: \"cuda:0 f32[1, 512, 1]\"\n", " # t238 = prims.broadcast_in_dim(t237, (1, 512, 4096), (0, 1, 2)) # t238: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t239 = prims.mul(t231, t238) # t239: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t241 = prims.mul(t239, t240) # t241: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t230\n", - " t242 = torch.nn.functional.linear(t241, t9, None) # t242: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t242 = ltorch.linear(t241, t9, None) # t242: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t242 = prims.linear(t241, t9, None) # t242: \"cuda:0 f32[1, 512, 11008]\"\n", - " t243 = torch.nn.functional.linear(t241, t13, None) # t243: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t243 = ltorch.linear(t241, t13, None) # t243: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t243 = prims.linear(t241, t13, None) # t243: \"cuda:0 f32[1, 512, 11008]\"\n", - " [t249] = nvFusion14(t242, t243)\n", - " # t244 = prims.neg(t242) # t244: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t245 = prims.exp(t244) # t245: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t246 = prims.add(1.0, t245) # t246: \"cuda:0 f32[1, 
512, 11008]\"\n", - " # t247 = prims.reciprocal(t246) # t247: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t248 = prims.mul(t242, t247) # t248: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t249 = prims.mul(t248, t243) # t249: \"cuda:0 f32[1, 512, 11008]\"\n", - " t250 = torch.nn.functional.linear(t249, t30, None) # t250: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t250 = ltorch.linear(t249, t30, None) # t250: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t250 = prims.linear(t249, t30, None) # t250: \"cuda:0 f32[1, 512, 4096]\"\n", - " [t251, t257, t261] = nvFusion15(t231, t250, t260)\n", - " # t251 = prims.add(t250, t231) # t251: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t252 = prims.mul(t251, t251) # t252: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t253 = prims.sum(t252, (2,)) # t253: \"cuda:0 f32[1, 512]\"\n", - " # t254 = prims.broadcast_in_dim(t253, [1, 512, 1], [0, 1]) # t254: \"cuda:0 f32[1, 512, 1]\"\n", - " # t255 = prims.div(t254, 4096.0) # t255: \"cuda:0 f32[1, 512, 1]\"\n", - " # t256 = prims.add(t255, 1e-05) # t256: \"cuda:0 f32[1, 512, 1]\"\n", - " # t257 = prims.rsqrt(t256) # t257: \"cuda:0 f32[1, 512, 1]\"\n", - " # t258 = prims.broadcast_in_dim(t257, (1, 512, 4096), (0, 1, 2)) # t258: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t259 = prims.mul(t251, t258) # t259: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t261 = prims.mul(t259, t260) # t261: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t239 = prims.mul(t229, t238) # t239: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t243 = prims.convert_element_type(t241, dtypes.float32) # t243: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t244 = prims.mul(t239, t243) # t244: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t245 = prims.convert_element_type(t244, dtypes.bfloat16) # t245: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t246 = torch.nn.functional.linear(t245, t4, None) # t246: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t246 = ltorch.linear(t245, t4, None) # t246: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t246 = prims.linear(t245, t4, None) # t246: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t247 = torch.reshape(t246, (1, 512, 32, 3, 128)) # t247: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t247 = ltorch.reshape(t246, (1, 512, 32, 3, 128)) # t247: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t247 = prims.reshape(t246, (1, 512, 32, 3, 128)) # t247: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t246\n", + " t248 = torch.permute(t247, (0, 2, 3, 1, 4)) # t248: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t248 = ltorch.permute(t247, (0, 2, 3, 1, 4)) # t248: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t248 = prims.transpose(t247, (0, 2, 3, 1, 4)) # t248: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t247\n", + " (t249, t250, t251) = torch.split(t248, (1, 1, 1), 2)\n", + " # (t249, t250, t251) = ltorch.split(t248, (1, 1, 1), 2)\n", + " # t249 = prims.slice_prim(t248, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t249: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t250 = prims.slice_prim(t248, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t250: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t251 = prims.slice_prim(t248, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t251: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t248\n", + " t252 = torch.reshape(t249, (1, 32, 512, 128)) # t252: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t252 = ltorch.reshape(t249, (1, 32, 512, 128)) # t252: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t252 = prims.reshape(t249, (1, 32, 512, 128)) # t252: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t249\n", + " t253 = torch.reshape(t250, (1, 32, 512, 
128)) # t253: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t253 = ltorch.reshape(t250, (1, 32, 512, 128)) # t253: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t253 = prims.reshape(t250, (1, 32, 512, 128)) # t253: \"cuda:0 bf16[1, 32, 512, 128]\"\n", " del t250\n", - " t262 = torch.nn.functional.linear(t261, t6, None) # t262: \"cuda:0 f32[1, 512, 12288]\"\n", - " # t262 = ltorch.linear(t261, t6, None) # t262: \"cuda:0 f32[1, 512, 12288]\"\n", - " # t262 = prims.linear(t261, t6, None) # t262: \"cuda:0 f32[1, 512, 12288]\"\n", - " t263 = torch.reshape(t262, (1, 512, 32, 3, 128)) # t263: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " # t263 = ltorch.reshape(t262, (1, 512, 32, 3, 128)) # t263: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " # t263 = prims.reshape(t262, (1, 512, 32, 3, 128)) # t263: \"cuda:0 f32[1, 512, 32, 3, 128]\"\n", - " del t262\n", - " t264 = torch.permute(t263, (0, 2, 3, 1, 4)) # t264: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " # t264 = ltorch.permute(t263, (0, 2, 3, 1, 4)) # t264: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " # t264 = prims.transpose(t263, (0, 2, 3, 1, 4)) # t264: \"cuda:0 f32[1, 32, 3, 512, 128]\"\n", - " del t263\n", - " (t265, t266, t267) = torch.split(t264, (1, 1, 1), 2)\n", - " # (t265, t266, t267) = ltorch.split(t264, (1, 1, 1), 2)\n", - " # t265 = prims.slice_prim(t264, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t265: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " # t266 = prims.slice_prim(t264, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t266: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " # t267 = prims.slice_prim(t264, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t267: \"cuda:0 f32[1, 32, 1, 512, 128]\"\n", - " del t264\n", - " t268 = torch.reshape(t265, (1, 32, 512, 128)) # t268: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t268 = ltorch.reshape(t265, (1, 32, 512, 128)) # t268: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t268 = prims.reshape(t265, (1, 32, 512, 128)) # t268: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t265\n", - " t269 = torch.reshape(t266, (1, 32, 512, 128)) # t269: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t269 = ltorch.reshape(t266, (1, 32, 512, 128)) # t269: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t269 = prims.reshape(t266, (1, 32, 512, 128)) # t269: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t266\n", - " t270 = torch.reshape(t267, (1, 32, 512, 128)) # t270: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t270 = ltorch.reshape(t267, (1, 32, 512, 128)) # t270: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t270 = prims.reshape(t267, (1, 32, 512, 128)) # t270: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t267\n", - " t271 = torch_slice_prim_impl(t268, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t271: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " t281 = torch_slice_prim_impl(t269, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t281: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " t291 = torch_slice_prim_impl(t268, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t291: \"cuda:0 f32[1, 32, 512, 0]\"\n", - " del t268\n", - " t293 = torch_slice_prim_impl(t269, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t293: \"cuda:0 f32[1, 32, 512, 0]\"\n", - " del t269\n", - " t272 = torch_slice_prim_impl(t271, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t272: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t273 = torch_slice_prim_impl(t271, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t273: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " t282 = torch_slice_prim_impl(t281, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t282: \"cuda:0 f32[1, 32, 512, 
64]\"\n", - " t283 = torch_slice_prim_impl(t281, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t283: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " [t274, t284] = nvFusion16(t273, t283)\n", + " t254 = torch.reshape(t251, (1, 32, 512, 128)) # t254: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t254 = ltorch.reshape(t251, (1, 32, 512, 128)) # t254: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t254 = prims.reshape(t251, (1, 32, 512, 128)) # t254: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t251\n", + " t285 = torch_slice_prim_impl(t252, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t285: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t287 = torch_slice_prim_impl(t253, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t287: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t255 = torch_slice_prim_impl(t252, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t255: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t252\n", + " t270 = torch_slice_prim_impl(t253, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t270: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t253\n", + " t256 = torch_slice_prim_impl(t255, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t256: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t257 = torch_slice_prim_impl(t255, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t257: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t272 = torch_slice_prim_impl(t270, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t272: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t271 = torch_slice_prim_impl(t270, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t271: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t260, t275] = nvFusion6(t255, t257, t270, t272)\n", + " # t258 = prims.convert_element_type(t257, dtypes.float32) # t258: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t259 = prims.neg(t258) # t259: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t260 = prims.convert_element_type(t259, dtypes.bfloat16) # t260: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t273 = prims.convert_element_type(t272, dtypes.float32) # t273: \"cuda:0 f32[1, 32, 512, 64]\"\n", " # t274 = prims.neg(t273) # t274: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " # t284 = prims.neg(t283) # t284: \"cuda:0 f32[1, 32, 512, 64]\"\n", - " del t273, t283\n", - " t275 = torch.cat((t274, t272), -1) # t275: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t275 = ltorch.cat((t274, t272), -1) # t275: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t275 = prims.cat((t274, t272), -1) # t275: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t274, t272\n", - " t285 = torch.cat((t284, t282), -1) # t285: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t285 = ltorch.cat((t284, t282), -1) # t285: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t285 = prims.cat((t284, t282), -1) # t285: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t284, t282\n", - " [t280, t290] = nvFusion17(t271, t275, t281, t285, t63, t65)\n", - " # t277 = prims.mul(t271, t63) # t277: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t279 = prims.mul(t275, t65) # t279: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t280 = prims.add(t277, t279) # t280: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t287 = prims.mul(t281, t63) # t287: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t289 = prims.mul(t285, t65) # t289: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t290 = prims.add(t287, t289) # t290: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t271, t275, t281, t285\n", - " t292 = torch.cat((t280, t291), -1) # t292: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t292 = ltorch.cat((t280, t291), -1) # t292: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t292 = prims.cat((t280, t291), -1) # t292: \"cuda:0 f32[1, 32, 512, 128]\"\n", 
- " del t280, t291\n", - " t294 = torch.cat((t290, t293), -1) # t294: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t294 = ltorch.cat((t290, t293), -1) # t294: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " # t294 = prims.cat((t290, t293), -1) # t294: \"cuda:0 f32[1, 32, 512, 128]\"\n", - " del t290, t293\n", - " (t295, t296, t297, t298) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t292, t294, t270, None, 0.0, True, 0.08838834764831843)\n", - " t299 = torch.permute(t295, (0, 2, 1, 3)) # t299: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " # t299 = ltorch.permute(t295, (0, 2, 1, 3)) # t299: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " # t299 = prims.transpose(t295, (0, 2, 1, 3)) # t299: \"cuda:0 f32[1, 512, 32, 128]\"\n", - " t300 = torch.reshape(t299, (1, 512, 4096)) # t300: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t300 = ltorch.reshape(t299, (1, 512, 4096)) # t300: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t300 = prims.reshape(t299, (1, 512, 4096)) # t300: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t299\n", - " t301 = torch.nn.functional.linear(t300, t31, None) # t301: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t301 = ltorch.linear(t300, t31, None) # t301: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t301 = prims.linear(t300, t31, None) # t301: \"cuda:0 f32[1, 512, 4096]\"\n", - " [t302, t308, t312] = nvFusion18(t251, t301, t311)\n", - " # t302 = prims.add(t301, t251) # t302: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t303 = prims.mul(t302, t302) # t303: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t304 = prims.sum(t303, (2,)) # t304: \"cuda:0 f32[1, 512]\"\n", - " # t305 = prims.broadcast_in_dim(t304, [1, 512, 1], [0, 1]) # t305: \"cuda:0 f32[1, 512, 1]\"\n", - " # t306 = prims.div(t305, 4096.0) # t306: \"cuda:0 f32[1, 512, 1]\"\n", - " # t307 = prims.add(t306, 1e-05) # t307: \"cuda:0 f32[1, 512, 1]\"\n", - " # t308 = prims.rsqrt(t307) # t308: \"cuda:0 f32[1, 512, 1]\"\n", - " # t309 = prims.broadcast_in_dim(t308, (1, 512, 4096), (0, 1, 2)) # t309: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t310 = prims.mul(t302, t309) # t310: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t312 = prims.mul(t310, t311) # t312: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t301\n", - " t314 = torch.nn.functional.linear(t312, t14, None) # t314: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t314 = ltorch.linear(t312, t14, None) # t314: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t314 = prims.linear(t312, t14, None) # t314: \"cuda:0 f32[1, 512, 11008]\"\n", - " t313 = torch.nn.functional.linear(t312, t10, None) # t313: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t313 = ltorch.linear(t312, t10, None) # t313: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t313 = prims.linear(t312, t10, None) # t313: \"cuda:0 f32[1, 512, 11008]\"\n", - " [t320] = nvFusion19(t313, t314)\n", - " # t315 = prims.neg(t313) # t315: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t316 = prims.exp(t315) # t316: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t317 = prims.add(1.0, t316) # t317: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t318 = prims.reciprocal(t317) # t318: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t319 = prims.mul(t313, t318) # t319: \"cuda:0 f32[1, 512, 11008]\"\n", - " # t320 = prims.mul(t319, t314) # t320: \"cuda:0 f32[1, 512, 11008]\"\n", - " t321 = torch.nn.functional.linear(t320, t32, None) # t321: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t321 = ltorch.linear(t320, t32, None) # t321: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t321 = prims.linear(t320, t32, None) # t321: \"cuda:0 f32[1, 512, 4096]\"\n", - " [t322, t328, t332] = nvFusion20(t302, t321, t331)\n", - " # t322 = prims.add(t321, t302) # t322: \"cuda:0 
f32[1, 512, 4096]\"\n", - " # t323 = prims.mul(t322, t322) # t323: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t324 = prims.sum(t323, (2,)) # t324: \"cuda:0 f32[1, 512]\"\n", - " # t325 = prims.broadcast_in_dim(t324, [1, 512, 1], [0, 1]) # t325: \"cuda:0 f32[1, 512, 1]\"\n", - " # t326 = prims.div(t325, 4096.0) # t326: \"cuda:0 f32[1, 512, 1]\"\n", - " # t327 = prims.add(t326, 1e-05) # t327: \"cuda:0 f32[1, 512, 1]\"\n", - " # t328 = prims.rsqrt(t327) # t328: \"cuda:0 f32[1, 512, 1]\"\n", - " # t329 = prims.broadcast_in_dim(t328, (1, 512, 4096), (0, 1, 2)) # t329: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t330 = prims.mul(t322, t329) # t330: \"cuda:0 f32[1, 512, 4096]\"\n", - " # t332 = prims.mul(t330, t331) # t332: \"cuda:0 f32[1, 512, 4096]\"\n", - " del t321\n", - " t333 = torch.nn.functional.linear(t332, t15, None) # t333: \"cuda:0 f32[1, 512, 32000]\"\n", - " # t333 = ltorch.linear(t332, t15, None) # t333: \"cuda:0 f32[1, 512, 32000]\"\n", - " # t333 = prims.linear(t332, t15, None) # t333: \"cuda:0 f32[1, 512, 32000]\"\n", - " return {'output': t333, 'flat_args': [t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19, t20, t21, t22, t23, t24, t25, t26, t27, t28, t29, t30, t31, t32, t33], 'flat_output': (t333,)}, ((t0, t10, t100, t101, t107, t109, t11, t115, t118, t119, t12, t128, t13, t14, t15, t150, t152, t153, t154, t155, t156, t158, t160, t166, t169, t170, t171, t172, t178, t180, t186, t189, t190, t199, t221, t223, t224, t225, t226, t227, t229, t231, t237, t240, t241, t242, t243, t249, t25, t251, t257, t26, t260, t261, t27, t270, t28, t29, t292, t294, t295, t296, t297, t298, t3, t30, t300, t302, t308, t31, t311, t312, t313, t314, t32, t320, t322, t328, t331, t332, t38, t4, t44, t47, t48, t5, t57, t6, t63, t65, t7, t79, t8, t81, t82, t83, t84, t85, t87, t89, t9, t95, t98, t99), (False, True, True, False, True, True, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 32000, 2, 2, 2, 2))" + " # t275 = prims.convert_element_type(t274, dtypes.bfloat16) # t275: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t257, t272\n", + " t261 = torch.cat((t260, t256), -1) # t261: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t261 = ltorch.cat((t260, t256), -1) # t261: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t261 = prims.cat((t260, t256), -1) # t261: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t260, t256\n", + " t276 = torch.cat((t275, t271), -1) # t276: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t276 = ltorch.cat((t275, t271), -1) # t276: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t276 = prims.cat((t275, t271), -1) # t276: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t275, t271\n", + " [t269, t284] = nvFusion7(t154, t157, t255, t261, t270, t276)\n", + " # t263 = prims.convert_element_type(t255, dtypes.float32) # t263: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t278 = prims.convert_element_type(t270, dtypes.float32) # t278: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t264 = prims.mul(t263, t154) # t264: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t266 = prims.convert_element_type(t261, dtypes.float32) # t266: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t267 = prims.mul(t266, t157) # t267: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t268 = prims.add(t264, t267) # t268: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t269 = prims.convert_element_type(t268, dtypes.bfloat16) # t269: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t279 = prims.mul(t278, t154) # t279: \"cuda:0 
f32[1, 32, 512, 128]\"\n", + " # t281 = prims.convert_element_type(t276, dtypes.float32) # t281: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t282 = prims.mul(t281, t157) # t282: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t283 = prims.add(t279, t282) # t283: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t284 = prims.convert_element_type(t283, dtypes.bfloat16) # t284: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t255, t261, t270, t276\n", + " t288 = torch.cat((t284, t287), -1) # t288: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t288 = ltorch.cat((t284, t287), -1) # t288: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t288 = prims.cat((t284, t287), -1) # t288: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t284, t287\n", + " t286 = torch.cat((t269, t285), -1) # t286: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t286 = ltorch.cat((t269, t285), -1) # t286: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t286 = prims.cat((t269, t285), -1) # t286: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t269, t285\n", + " (t289, t290, t291, t292, _, _, t293, t294, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t286, t288, t254, 0.0, True, scale=0.08838834764831843)\n", + " t296 = torch.permute(t289, (0, 2, 1, 3)) # t296: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t296 = ltorch.permute(t289, (0, 2, 1, 3)) # t296: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t296 = prims.transpose(t289, (0, 2, 1, 3)) # t296: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t297 = torch.reshape(t296, (1, 512, 4096)) # t297: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t297 = ltorch.reshape(t296, (1, 512, 4096)) # t297: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t297 = prims.reshape(t296, (1, 512, 4096)) # t297: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t296\n", + " t298 = torch.nn.functional.linear(t297, t87, None) # t298: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t298 = ltorch.linear(t297, t87, None) # t298: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t298 = prims.linear(t297, t87, None) # t298: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t302, t309, t317] = nvFusion8(t230, t298, t313)\n", + " # t300 = prims.convert_element_type(t230, dtypes.float32) # t300: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t299 = prims.convert_element_type(t298, dtypes.float32) # t299: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t301 = prims.add(t299, t300) # t301: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t302 = prims.convert_element_type(t301, dtypes.bfloat16) # t302: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t304 = prims.mul(t301, t301) # t304: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t305 = prims.sum(t304, (2,)) # t305: \"cuda:0 f32[1, 512]\"\n", + " # t306 = prims.broadcast_in_dim(t305, [1, 512, 1], [0, 1]) # t306: \"cuda:0 f32[1, 512, 1]\"\n", + " # t307 = prims.div(t306, 4096.0) # t307: \"cuda:0 f32[1, 512, 1]\"\n", + " # t308 = prims.add(t307, 1e-05) # t308: \"cuda:0 f32[1, 512, 1]\"\n", + " # t309 = prims.rsqrt(t308) # t309: \"cuda:0 f32[1, 512, 1]\"\n", + " # t310 = prims.broadcast_in_dim(t309, (1, 512, 4096), (0, 1, 2)) # t310: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t311 = prims.mul(t301, t310) # t311: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t315 = prims.convert_element_type(t313, dtypes.float32) # t315: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t316 = prims.mul(t311, t315) # t316: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t317 = prims.convert_element_type(t316, dtypes.bfloat16) # t317: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t318 = torch.nn.functional.linear(t317, t20, None) # t318: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t318 = ltorch.linear(t317, t20, None) # t318: \"cuda:0 bf16[1, 512, 
11008]\"\n", + " # t318 = prims.linear(t317, t20, None) # t318: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t319 = torch.nn.functional.linear(t317, t36, None) # t319: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t319 = ltorch.linear(t317, t36, None) # t319: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t319 = prims.linear(t317, t36, None) # t319: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t333] = nvFusion9(t318, t319)\n", + " # t320 = prims.convert_element_type(t318, dtypes.float32) # t320: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t321 = prims.neg(t320) # t321: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t322 = prims.exp(t321) # t322: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t323 = prims.add(1.0, t322) # t323: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t324 = prims.reciprocal(t323) # t324: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t328 = prims.mul(t320, t324) # t328: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t331 = prims.convert_element_type(t319, dtypes.float32) # t331: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t332 = prims.mul(t328, t331) # t332: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t333 = prims.convert_element_type(t332, dtypes.bfloat16) # t333: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t334 = torch.nn.functional.linear(t333, t88, None) # t334: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t334 = ltorch.linear(t333, t88, None) # t334: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t334 = prims.linear(t333, t88, None) # t334: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t338, t345, t353] = nvFusion10(t302, t334, t349)\n", + " # t336 = prims.convert_element_type(t302, dtypes.float32) # t336: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t335 = prims.convert_element_type(t334, dtypes.float32) # t335: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t337 = prims.add(t335, t336) # t337: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t338 = prims.convert_element_type(t337, dtypes.bfloat16) # t338: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t340 = prims.mul(t337, t337) # t340: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t341 = prims.sum(t340, (2,)) # t341: \"cuda:0 f32[1, 512]\"\n", + " # t342 = prims.broadcast_in_dim(t341, [1, 512, 1], [0, 1]) # t342: \"cuda:0 f32[1, 512, 1]\"\n", + " # t343 = prims.div(t342, 4096.0) # t343: \"cuda:0 f32[1, 512, 1]\"\n", + " # t344 = prims.add(t343, 1e-05) # t344: \"cuda:0 f32[1, 512, 1]\"\n", + " # t345 = prims.rsqrt(t344) # t345: \"cuda:0 f32[1, 512, 1]\"\n", + " # t346 = prims.broadcast_in_dim(t345, (1, 512, 4096), (0, 1, 2)) # t346: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t347 = prims.mul(t337, t346) # t347: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t351 = prims.convert_element_type(t349, dtypes.float32) # t351: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t352 = prims.mul(t347, t351) # t352: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t353 = prims.convert_element_type(t352, dtypes.bfloat16) # t353: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t354 = torch.nn.functional.linear(t353, t5, None) # t354: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t354 = ltorch.linear(t353, t5, None) # t354: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t354 = prims.linear(t353, t5, None) # t354: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t355 = torch.reshape(t354, (1, 512, 32, 3, 128)) # t355: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t355 = ltorch.reshape(t354, (1, 512, 32, 3, 128)) # t355: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t355 = prims.reshape(t354, (1, 512, 32, 3, 128)) # t355: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t354\n", + " t356 = torch.permute(t355, (0, 2, 3, 1, 4)) # t356: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t356 = ltorch.permute(t355, (0, 2, 3, 
1, 4)) # t356: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t356 = prims.transpose(t355, (0, 2, 3, 1, 4)) # t356: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t355\n", + " (t357, t358, t359) = torch.split(t356, (1, 1, 1), 2)\n", + " # (t357, t358, t359) = ltorch.split(t356, (1, 1, 1), 2)\n", + " # t357 = prims.slice_prim(t356, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t357: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t358 = prims.slice_prim(t356, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t358: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t359 = prims.slice_prim(t356, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t359: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t356\n", + " t360 = torch.reshape(t357, (1, 32, 512, 128)) # t360: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t360 = ltorch.reshape(t357, (1, 32, 512, 128)) # t360: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t360 = prims.reshape(t357, (1, 32, 512, 128)) # t360: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t357\n", + " t361 = torch.reshape(t358, (1, 32, 512, 128)) # t361: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t361 = ltorch.reshape(t358, (1, 32, 512, 128)) # t361: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t361 = prims.reshape(t358, (1, 32, 512, 128)) # t361: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t358\n", + " t362 = torch.reshape(t359, (1, 32, 512, 128)) # t362: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t362 = ltorch.reshape(t359, (1, 32, 512, 128)) # t362: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t362 = prims.reshape(t359, (1, 32, 512, 128)) # t362: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t359\n", + " t363 = torch_slice_prim_impl(t360, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t363: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t378 = torch_slice_prim_impl(t361, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t378: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t393 = torch_slice_prim_impl(t360, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t393: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t360\n", + " t395 = torch_slice_prim_impl(t361, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t395: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t361\n", + " t364 = torch_slice_prim_impl(t363, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t364: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t365 = torch_slice_prim_impl(t363, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t365: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t379 = torch_slice_prim_impl(t378, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t379: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t380 = torch_slice_prim_impl(t378, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t380: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t368, t383] = nvFusion11(t363, t365, t378, t380)\n", + " # t366 = prims.convert_element_type(t365, dtypes.float32) # t366: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t367 = prims.neg(t366) # t367: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t368 = prims.convert_element_type(t367, dtypes.bfloat16) # t368: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t381 = prims.convert_element_type(t380, dtypes.float32) # t381: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t382 = prims.neg(t381) # t382: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t383 = prims.convert_element_type(t382, dtypes.bfloat16) # t383: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t365, t380\n", + " t369 = torch.cat((t368, t364), -1) # t369: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t369 = ltorch.cat((t368, t364), -1) # t369: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # 
t369 = prims.cat((t368, t364), -1) # t369: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t368, t364\n", + " t384 = torch.cat((t383, t379), -1) # t384: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t384 = ltorch.cat((t383, t379), -1) # t384: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t384 = prims.cat((t383, t379), -1) # t384: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t383, t379\n", + " [t377, t392] = nvFusion12(t154, t157, t363, t369, t378, t384)\n", + " # t371 = prims.convert_element_type(t363, dtypes.float32) # t371: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t386 = prims.convert_element_type(t378, dtypes.float32) # t386: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t372 = prims.mul(t371, t154) # t372: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t374 = prims.convert_element_type(t369, dtypes.float32) # t374: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t375 = prims.mul(t374, t157) # t375: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t376 = prims.add(t372, t375) # t376: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t377 = prims.convert_element_type(t376, dtypes.bfloat16) # t377: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t387 = prims.mul(t386, t154) # t387: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t389 = prims.convert_element_type(t384, dtypes.float32) # t389: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t390 = prims.mul(t389, t157) # t390: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t391 = prims.add(t387, t390) # t391: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t392 = prims.convert_element_type(t391, dtypes.bfloat16) # t392: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t363, t369, t378, t384\n", + " t394 = torch.cat((t377, t393), -1) # t394: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t394 = ltorch.cat((t377, t393), -1) # t394: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t394 = prims.cat((t377, t393), -1) # t394: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t377, t393\n", + " t396 = torch.cat((t392, t395), -1) # t396: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t396 = ltorch.cat((t392, t395), -1) # t396: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t396 = prims.cat((t392, t395), -1) # t396: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t392, t395\n", + " (t397, t398, t399, t400, _, _, t401, t402, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t394, t396, t362, 0.0, True, scale=0.08838834764831843)\n", + " t404 = torch.permute(t397, (0, 2, 1, 3)) # t404: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t404 = ltorch.permute(t397, (0, 2, 1, 3)) # t404: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t404 = prims.transpose(t397, (0, 2, 1, 3)) # t404: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t405 = torch.reshape(t404, (1, 512, 4096)) # t405: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t405 = ltorch.reshape(t404, (1, 512, 4096)) # t405: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t405 = prims.reshape(t404, (1, 512, 4096)) # t405: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t404\n", + " t406 = torch.nn.functional.linear(t405, t89, None) # t406: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t406 = ltorch.linear(t405, t89, None) # t406: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t406 = prims.linear(t405, t89, None) # t406: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t410, t417, t425] = nvFusion13(t338, t406, t421)\n", + " # t408 = prims.convert_element_type(t338, dtypes.float32) # t408: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t407 = prims.convert_element_type(t406, dtypes.float32) # t407: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t409 = prims.add(t407, t408) # t409: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t410 = prims.convert_element_type(t409, 
dtypes.bfloat16) # t410: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t412 = prims.mul(t409, t409) # t412: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t413 = prims.sum(t412, (2,)) # t413: \"cuda:0 f32[1, 512]\"\n", + " # t414 = prims.broadcast_in_dim(t413, [1, 512, 1], [0, 1]) # t414: \"cuda:0 f32[1, 512, 1]\"\n", + " # t415 = prims.div(t414, 4096.0) # t415: \"cuda:0 f32[1, 512, 1]\"\n", + " # t416 = prims.add(t415, 1e-05) # t416: \"cuda:0 f32[1, 512, 1]\"\n", + " # t417 = prims.rsqrt(t416) # t417: \"cuda:0 f32[1, 512, 1]\"\n", + " # t418 = prims.broadcast_in_dim(t417, (1, 512, 4096), (0, 1, 2)) # t418: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t419 = prims.mul(t409, t418) # t419: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t423 = prims.convert_element_type(t421, dtypes.float32) # t423: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t424 = prims.mul(t419, t423) # t424: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t425 = prims.convert_element_type(t424, dtypes.bfloat16) # t425: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t426 = torch.nn.functional.linear(t425, t21, None) # t426: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t426 = ltorch.linear(t425, t21, None) # t426: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t426 = prims.linear(t425, t21, None) # t426: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t427 = torch.nn.functional.linear(t425, t37, None) # t427: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t427 = ltorch.linear(t425, t37, None) # t427: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t427 = prims.linear(t425, t37, None) # t427: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t441] = nvFusion14(t426, t427)\n", + " # t428 = prims.convert_element_type(t426, dtypes.float32) # t428: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t429 = prims.neg(t428) # t429: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t430 = prims.exp(t429) # t430: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t431 = prims.add(1.0, t430) # t431: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t432 = prims.reciprocal(t431) # t432: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t436 = prims.mul(t428, t432) # t436: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t439 = prims.convert_element_type(t427, dtypes.float32) # t439: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t440 = prims.mul(t436, t439) # t440: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t441 = prims.convert_element_type(t440, dtypes.bfloat16) # t441: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t442 = torch.nn.functional.linear(t441, t90, None) # t442: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t442 = ltorch.linear(t441, t90, None) # t442: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t442 = prims.linear(t441, t90, None) # t442: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t446, t453, t461] = nvFusion15(t410, t442, t457)\n", + " # t444 = prims.convert_element_type(t410, dtypes.float32) # t444: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t443 = prims.convert_element_type(t442, dtypes.float32) # t443: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t445 = prims.add(t443, t444) # t445: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t446 = prims.convert_element_type(t445, dtypes.bfloat16) # t446: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t448 = prims.mul(t445, t445) # t448: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t449 = prims.sum(t448, (2,)) # t449: \"cuda:0 f32[1, 512]\"\n", + " # t450 = prims.broadcast_in_dim(t449, [1, 512, 1], [0, 1]) # t450: \"cuda:0 f32[1, 512, 1]\"\n", + " # t451 = prims.div(t450, 4096.0) # t451: \"cuda:0 f32[1, 512, 1]\"\n", + " # t452 = prims.add(t451, 1e-05) # t452: \"cuda:0 f32[1, 512, 1]\"\n", + " # t453 = prims.rsqrt(t452) # t453: \"cuda:0 f32[1, 512, 1]\"\n", + " # t454 = 
prims.broadcast_in_dim(t453, (1, 512, 4096), (0, 1, 2)) # t454: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t455 = prims.mul(t445, t454) # t455: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t459 = prims.convert_element_type(t457, dtypes.float32) # t459: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t460 = prims.mul(t455, t459) # t460: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t461 = prims.convert_element_type(t460, dtypes.bfloat16) # t461: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t462 = torch.nn.functional.linear(t461, t6, None) # t462: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t462 = ltorch.linear(t461, t6, None) # t462: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t462 = prims.linear(t461, t6, None) # t462: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t463 = torch.reshape(t462, (1, 512, 32, 3, 128)) # t463: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t463 = ltorch.reshape(t462, (1, 512, 32, 3, 128)) # t463: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t463 = prims.reshape(t462, (1, 512, 32, 3, 128)) # t463: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t462\n", + " t464 = torch.permute(t463, (0, 2, 3, 1, 4)) # t464: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t464 = ltorch.permute(t463, (0, 2, 3, 1, 4)) # t464: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t464 = prims.transpose(t463, (0, 2, 3, 1, 4)) # t464: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t463\n", + " (t465, t466, t467) = torch.split(t464, (1, 1, 1), 2)\n", + " # (t465, t466, t467) = ltorch.split(t464, (1, 1, 1), 2)\n", + " # t465 = prims.slice_prim(t464, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t465: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t466 = prims.slice_prim(t464, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t466: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t467 = prims.slice_prim(t464, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t467: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t464\n", + " t468 = torch.reshape(t465, (1, 32, 512, 128)) # t468: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t468 = ltorch.reshape(t465, (1, 32, 512, 128)) # t468: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t468 = prims.reshape(t465, (1, 32, 512, 128)) # t468: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t465\n", + " t469 = torch.reshape(t466, (1, 32, 512, 128)) # t469: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t469 = ltorch.reshape(t466, (1, 32, 512, 128)) # t469: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t469 = prims.reshape(t466, (1, 32, 512, 128)) # t469: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t466\n", + " t470 = torch.reshape(t467, (1, 32, 512, 128)) # t470: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t470 = ltorch.reshape(t467, (1, 32, 512, 128)) # t470: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t470 = prims.reshape(t467, (1, 32, 512, 128)) # t470: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t467\n", + " t471 = torch_slice_prim_impl(t468, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t471: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t486 = torch_slice_prim_impl(t469, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t486: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t501 = torch_slice_prim_impl(t468, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t501: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t468\n", + " t503 = torch_slice_prim_impl(t469, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t503: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t469\n", + " t472 = torch_slice_prim_impl(t471, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t472: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t473 = 
torch_slice_prim_impl(t471, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t473: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t487 = torch_slice_prim_impl(t486, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t487: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t488 = torch_slice_prim_impl(t486, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t488: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t476, t491] = nvFusion16(t471, t473, t486, t488)\n", + " # t474 = prims.convert_element_type(t473, dtypes.float32) # t474: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t475 = prims.neg(t474) # t475: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t476 = prims.convert_element_type(t475, dtypes.bfloat16) # t476: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t489 = prims.convert_element_type(t488, dtypes.float32) # t489: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t490 = prims.neg(t489) # t490: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t491 = prims.convert_element_type(t490, dtypes.bfloat16) # t491: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t473, t488\n", + " t477 = torch.cat((t476, t472), -1) # t477: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t477 = ltorch.cat((t476, t472), -1) # t477: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t477 = prims.cat((t476, t472), -1) # t477: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t476, t472\n", + " t492 = torch.cat((t491, t487), -1) # t492: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t492 = ltorch.cat((t491, t487), -1) # t492: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t492 = prims.cat((t491, t487), -1) # t492: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t491, t487\n", + " [t485, t500] = nvFusion17(t154, t157, t471, t477, t486, t492)\n", + " # t479 = prims.convert_element_type(t471, dtypes.float32) # t479: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t494 = prims.convert_element_type(t486, dtypes.float32) # t494: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t480 = prims.mul(t479, t154) # t480: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t482 = prims.convert_element_type(t477, dtypes.float32) # t482: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t483 = prims.mul(t482, t157) # t483: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t484 = prims.add(t480, t483) # t484: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t485 = prims.convert_element_type(t484, dtypes.bfloat16) # t485: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t495 = prims.mul(t494, t154) # t495: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t497 = prims.convert_element_type(t492, dtypes.float32) # t497: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t498 = prims.mul(t497, t157) # t498: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t499 = prims.add(t495, t498) # t499: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t500 = prims.convert_element_type(t499, dtypes.bfloat16) # t500: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t471, t477, t486, t492\n", + " t502 = torch.cat((t485, t501), -1) # t502: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t502 = ltorch.cat((t485, t501), -1) # t502: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t502 = prims.cat((t485, t501), -1) # t502: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t485, t501\n", + " t504 = torch.cat((t500, t503), -1) # t504: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t504 = ltorch.cat((t500, t503), -1) # t504: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t504 = prims.cat((t500, t503), -1) # t504: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t500, t503\n", + " (t505, t506, t507, t508, _, _, t509, t510, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t502, t504, t470, 0.0, True, scale=0.08838834764831843)\n", + " t512 = 
torch.permute(t505, (0, 2, 1, 3)) # t512: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t512 = ltorch.permute(t505, (0, 2, 1, 3)) # t512: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t512 = prims.transpose(t505, (0, 2, 1, 3)) # t512: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t513 = torch.reshape(t512, (1, 512, 4096)) # t513: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t513 = ltorch.reshape(t512, (1, 512, 4096)) # t513: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t513 = prims.reshape(t512, (1, 512, 4096)) # t513: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t512\n", + " t514 = torch.nn.functional.linear(t513, t91, None) # t514: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t514 = ltorch.linear(t513, t91, None) # t514: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t514 = prims.linear(t513, t91, None) # t514: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t518, t525, t533] = nvFusion18(t446, t514, t529)\n", + " # t516 = prims.convert_element_type(t446, dtypes.float32) # t516: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t515 = prims.convert_element_type(t514, dtypes.float32) # t515: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t517 = prims.add(t515, t516) # t517: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t518 = prims.convert_element_type(t517, dtypes.bfloat16) # t518: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t520 = prims.mul(t517, t517) # t520: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t521 = prims.sum(t520, (2,)) # t521: \"cuda:0 f32[1, 512]\"\n", + " # t522 = prims.broadcast_in_dim(t521, [1, 512, 1], [0, 1]) # t522: \"cuda:0 f32[1, 512, 1]\"\n", + " # t523 = prims.div(t522, 4096.0) # t523: \"cuda:0 f32[1, 512, 1]\"\n", + " # t524 = prims.add(t523, 1e-05) # t524: \"cuda:0 f32[1, 512, 1]\"\n", + " # t525 = prims.rsqrt(t524) # t525: \"cuda:0 f32[1, 512, 1]\"\n", + " # t526 = prims.broadcast_in_dim(t525, (1, 512, 4096), (0, 1, 2)) # t526: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t527 = prims.mul(t517, t526) # t527: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t531 = prims.convert_element_type(t529, dtypes.float32) # t531: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t532 = prims.mul(t527, t531) # t532: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t533 = prims.convert_element_type(t532, dtypes.bfloat16) # t533: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t534 = torch.nn.functional.linear(t533, t22, None) # t534: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t534 = ltorch.linear(t533, t22, None) # t534: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t534 = prims.linear(t533, t22, None) # t534: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t535 = torch.nn.functional.linear(t533, t38, None) # t535: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t535 = ltorch.linear(t533, t38, None) # t535: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t535 = prims.linear(t533, t38, None) # t535: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t549] = nvFusion19(t534, t535)\n", + " # t536 = prims.convert_element_type(t534, dtypes.float32) # t536: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t537 = prims.neg(t536) # t537: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t538 = prims.exp(t537) # t538: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t539 = prims.add(1.0, t538) # t539: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t540 = prims.reciprocal(t539) # t540: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t544 = prims.mul(t536, t540) # t544: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t547 = prims.convert_element_type(t535, dtypes.float32) # t547: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t548 = prims.mul(t544, t547) # t548: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t549 = prims.convert_element_type(t548, dtypes.bfloat16) # t549: \"cuda:0 bf16[1, 512, 11008]\"\n", + " 
t550 = torch.nn.functional.linear(t549, t92, None) # t550: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t550 = ltorch.linear(t549, t92, None) # t550: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t550 = prims.linear(t549, t92, None) # t550: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t554, t561, t569] = nvFusion20(t518, t550, t565)\n", + " # t552 = prims.convert_element_type(t518, dtypes.float32) # t552: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t551 = prims.convert_element_type(t550, dtypes.float32) # t551: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t553 = prims.add(t551, t552) # t553: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t554 = prims.convert_element_type(t553, dtypes.bfloat16) # t554: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t556 = prims.mul(t553, t553) # t556: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t557 = prims.sum(t556, (2,)) # t557: \"cuda:0 f32[1, 512]\"\n", + " # t558 = prims.broadcast_in_dim(t557, [1, 512, 1], [0, 1]) # t558: \"cuda:0 f32[1, 512, 1]\"\n", + " # t559 = prims.div(t558, 4096.0) # t559: \"cuda:0 f32[1, 512, 1]\"\n", + " # t560 = prims.add(t559, 1e-05) # t560: \"cuda:0 f32[1, 512, 1]\"\n", + " # t561 = prims.rsqrt(t560) # t561: \"cuda:0 f32[1, 512, 1]\"\n", + " # t562 = prims.broadcast_in_dim(t561, (1, 512, 4096), (0, 1, 2)) # t562: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t563 = prims.mul(t553, t562) # t563: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t567 = prims.convert_element_type(t565, dtypes.float32) # t567: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t568 = prims.mul(t563, t567) # t568: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t569 = prims.convert_element_type(t568, dtypes.bfloat16) # t569: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t570 = torch.nn.functional.linear(t569, t7, None) # t570: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t570 = ltorch.linear(t569, t7, None) # t570: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t570 = prims.linear(t569, t7, None) # t570: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t571 = torch.reshape(t570, (1, 512, 32, 3, 128)) # t571: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t571 = ltorch.reshape(t570, (1, 512, 32, 3, 128)) # t571: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t571 = prims.reshape(t570, (1, 512, 32, 3, 128)) # t571: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t570\n", + " t572 = torch.permute(t571, (0, 2, 3, 1, 4)) # t572: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t572 = ltorch.permute(t571, (0, 2, 3, 1, 4)) # t572: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t572 = prims.transpose(t571, (0, 2, 3, 1, 4)) # t572: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t571\n", + " (t573, t574, t575) = torch.split(t572, (1, 1, 1), 2)\n", + " # (t573, t574, t575) = ltorch.split(t572, (1, 1, 1), 2)\n", + " # t573 = prims.slice_prim(t572, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t573: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t574 = prims.slice_prim(t572, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t574: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t575 = prims.slice_prim(t572, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t575: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t572\n", + " t576 = torch.reshape(t573, (1, 32, 512, 128)) # t576: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t576 = ltorch.reshape(t573, (1, 32, 512, 128)) # t576: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t576 = prims.reshape(t573, (1, 32, 512, 128)) # t576: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t573\n", + " t577 = torch.reshape(t574, (1, 32, 512, 128)) # t577: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t577 = ltorch.reshape(t574, (1, 32, 
512, 128)) # t577: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t577 = prims.reshape(t574, (1, 32, 512, 128)) # t577: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t574\n", + " t578 = torch.reshape(t575, (1, 32, 512, 128)) # t578: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t578 = ltorch.reshape(t575, (1, 32, 512, 128)) # t578: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t578 = prims.reshape(t575, (1, 32, 512, 128)) # t578: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t575\n", + " t579 = torch_slice_prim_impl(t576, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t579: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t594 = torch_slice_prim_impl(t577, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t594: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t609 = torch_slice_prim_impl(t576, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t609: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t576\n", + " t611 = torch_slice_prim_impl(t577, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t611: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t577\n", + " t580 = torch_slice_prim_impl(t579, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t580: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t581 = torch_slice_prim_impl(t579, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t581: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t595 = torch_slice_prim_impl(t594, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t595: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t596 = torch_slice_prim_impl(t594, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t596: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t584, t599] = nvFusion21(t579, t581, t594, t596)\n", + " # t582 = prims.convert_element_type(t581, dtypes.float32) # t582: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t583 = prims.neg(t582) # t583: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t584 = prims.convert_element_type(t583, dtypes.bfloat16) # t584: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t597 = prims.convert_element_type(t596, dtypes.float32) # t597: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t598 = prims.neg(t597) # t598: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t599 = prims.convert_element_type(t598, dtypes.bfloat16) # t599: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t581, t596\n", + " t600 = torch.cat((t599, t595), -1) # t600: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t600 = ltorch.cat((t599, t595), -1) # t600: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t600 = prims.cat((t599, t595), -1) # t600: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t599, t595\n", + " t585 = torch.cat((t584, t580), -1) # t585: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t585 = ltorch.cat((t584, t580), -1) # t585: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t585 = prims.cat((t584, t580), -1) # t585: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t584, t580\n", + " [t593, t608] = nvFusion22(t154, t157, t579, t585, t594, t600)\n", + " # t587 = prims.convert_element_type(t579, dtypes.float32) # t587: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t602 = prims.convert_element_type(t594, dtypes.float32) # t602: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t603 = prims.mul(t602, t154) # t603: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t605 = prims.convert_element_type(t600, dtypes.float32) # t605: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t606 = prims.mul(t605, t157) # t606: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t607 = prims.add(t603, t606) # t607: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t608 = prims.convert_element_type(t607, dtypes.bfloat16) # t608: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t588 = prims.mul(t587, t154) # t588: \"cuda:0 f32[1, 32, 
512, 128]\"\n", + " # t590 = prims.convert_element_type(t585, dtypes.float32) # t590: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t591 = prims.mul(t590, t157) # t591: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t592 = prims.add(t588, t591) # t592: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t593 = prims.convert_element_type(t592, dtypes.bfloat16) # t593: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t579, t585, t594, t600\n", + " t612 = torch.cat((t608, t611), -1) # t612: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t612 = ltorch.cat((t608, t611), -1) # t612: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t612 = prims.cat((t608, t611), -1) # t612: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t608, t611\n", + " t610 = torch.cat((t593, t609), -1) # t610: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t610 = ltorch.cat((t593, t609), -1) # t610: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t610 = prims.cat((t593, t609), -1) # t610: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t593, t609\n", + " (t613, t614, t615, t616, _, _, t617, t618, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t610, t612, t578, 0.0, True, scale=0.08838834764831843)\n", + " t620 = torch.permute(t613, (0, 2, 1, 3)) # t620: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t620 = ltorch.permute(t613, (0, 2, 1, 3)) # t620: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t620 = prims.transpose(t613, (0, 2, 1, 3)) # t620: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t621 = torch.reshape(t620, (1, 512, 4096)) # t621: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t621 = ltorch.reshape(t620, (1, 512, 4096)) # t621: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t621 = prims.reshape(t620, (1, 512, 4096)) # t621: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t620\n", + " t622 = torch.nn.functional.linear(t621, t93, None) # t622: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t622 = ltorch.linear(t621, t93, None) # t622: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t622 = prims.linear(t621, t93, None) # t622: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t626, t633, t641] = nvFusion23(t554, t622, t637)\n", + " # t624 = prims.convert_element_type(t554, dtypes.float32) # t624: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t623 = prims.convert_element_type(t622, dtypes.float32) # t623: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t625 = prims.add(t623, t624) # t625: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t626 = prims.convert_element_type(t625, dtypes.bfloat16) # t626: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t628 = prims.mul(t625, t625) # t628: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t629 = prims.sum(t628, (2,)) # t629: \"cuda:0 f32[1, 512]\"\n", + " # t630 = prims.broadcast_in_dim(t629, [1, 512, 1], [0, 1]) # t630: \"cuda:0 f32[1, 512, 1]\"\n", + " # t631 = prims.div(t630, 4096.0) # t631: \"cuda:0 f32[1, 512, 1]\"\n", + " # t632 = prims.add(t631, 1e-05) # t632: \"cuda:0 f32[1, 512, 1]\"\n", + " # t633 = prims.rsqrt(t632) # t633: \"cuda:0 f32[1, 512, 1]\"\n", + " # t634 = prims.broadcast_in_dim(t633, (1, 512, 4096), (0, 1, 2)) # t634: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t635 = prims.mul(t625, t634) # t635: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t639 = prims.convert_element_type(t637, dtypes.float32) # t639: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t640 = prims.mul(t635, t639) # t640: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t641 = prims.convert_element_type(t640, dtypes.bfloat16) # t641: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t643 = torch.nn.functional.linear(t641, t39, None) # t643: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t643 = ltorch.linear(t641, t39, None) # t643: \"cuda:0 bf16[1, 512, 11008]\"\n", + 
" # t643 = prims.linear(t641, t39, None) # t643: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t642 = torch.nn.functional.linear(t641, t23, None) # t642: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t642 = ltorch.linear(t641, t23, None) # t642: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t642 = prims.linear(t641, t23, None) # t642: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t657] = nvFusion24(t642, t643)\n", + " # t644 = prims.convert_element_type(t642, dtypes.float32) # t644: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t645 = prims.neg(t644) # t645: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t646 = prims.exp(t645) # t646: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t647 = prims.add(1.0, t646) # t647: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t648 = prims.reciprocal(t647) # t648: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t652 = prims.mul(t644, t648) # t652: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t655 = prims.convert_element_type(t643, dtypes.float32) # t655: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t656 = prims.mul(t652, t655) # t656: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t657 = prims.convert_element_type(t656, dtypes.bfloat16) # t657: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t658 = torch.nn.functional.linear(t657, t94, None) # t658: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t658 = ltorch.linear(t657, t94, None) # t658: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t658 = prims.linear(t657, t94, None) # t658: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t662, t669, t677] = nvFusion25(t626, t658, t673)\n", + " # t660 = prims.convert_element_type(t626, dtypes.float32) # t660: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t659 = prims.convert_element_type(t658, dtypes.float32) # t659: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t661 = prims.add(t659, t660) # t661: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t662 = prims.convert_element_type(t661, dtypes.bfloat16) # t662: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t664 = prims.mul(t661, t661) # t664: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t665 = prims.sum(t664, (2,)) # t665: \"cuda:0 f32[1, 512]\"\n", + " # t666 = prims.broadcast_in_dim(t665, [1, 512, 1], [0, 1]) # t666: \"cuda:0 f32[1, 512, 1]\"\n", + " # t667 = prims.div(t666, 4096.0) # t667: \"cuda:0 f32[1, 512, 1]\"\n", + " # t668 = prims.add(t667, 1e-05) # t668: \"cuda:0 f32[1, 512, 1]\"\n", + " # t669 = prims.rsqrt(t668) # t669: \"cuda:0 f32[1, 512, 1]\"\n", + " # t670 = prims.broadcast_in_dim(t669, (1, 512, 4096), (0, 1, 2)) # t670: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t671 = prims.mul(t661, t670) # t671: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t675 = prims.convert_element_type(t673, dtypes.float32) # t675: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t676 = prims.mul(t671, t675) # t676: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t677 = prims.convert_element_type(t676, dtypes.bfloat16) # t677: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t678 = torch.nn.functional.linear(t677, t8, None) # t678: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t678 = ltorch.linear(t677, t8, None) # t678: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t678 = prims.linear(t677, t8, None) # t678: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t679 = torch.reshape(t678, (1, 512, 32, 3, 128)) # t679: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t679 = ltorch.reshape(t678, (1, 512, 32, 3, 128)) # t679: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t679 = prims.reshape(t678, (1, 512, 32, 3, 128)) # t679: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t678\n", + " t680 = torch.permute(t679, (0, 2, 3, 1, 4)) # t680: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t680 = ltorch.permute(t679, (0, 2, 3, 1, 4)) # t680: 
\"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t680 = prims.transpose(t679, (0, 2, 3, 1, 4)) # t680: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t679\n", + " (t681, t682, t683) = torch.split(t680, (1, 1, 1), 2)\n", + " # (t681, t682, t683) = ltorch.split(t680, (1, 1, 1), 2)\n", + " # t681 = prims.slice_prim(t680, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t681: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t682 = prims.slice_prim(t680, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t682: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t683 = prims.slice_prim(t680, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t683: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t680\n", + " t684 = torch.reshape(t681, (1, 32, 512, 128)) # t684: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t684 = ltorch.reshape(t681, (1, 32, 512, 128)) # t684: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t684 = prims.reshape(t681, (1, 32, 512, 128)) # t684: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t681\n", + " t685 = torch.reshape(t682, (1, 32, 512, 128)) # t685: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t685 = ltorch.reshape(t682, (1, 32, 512, 128)) # t685: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t685 = prims.reshape(t682, (1, 32, 512, 128)) # t685: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t682\n", + " t686 = torch.reshape(t683, (1, 32, 512, 128)) # t686: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t686 = ltorch.reshape(t683, (1, 32, 512, 128)) # t686: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t686 = prims.reshape(t683, (1, 32, 512, 128)) # t686: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t683\n", + " t687 = torch_slice_prim_impl(t684, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t687: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t702 = torch_slice_prim_impl(t685, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t702: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t717 = torch_slice_prim_impl(t684, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t717: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t684\n", + " t719 = torch_slice_prim_impl(t685, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t719: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t685\n", + " t688 = torch_slice_prim_impl(t687, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t688: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t689 = torch_slice_prim_impl(t687, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t689: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t703 = torch_slice_prim_impl(t702, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t703: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t704 = torch_slice_prim_impl(t702, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t704: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t692, t707] = nvFusion26(t687, t689, t702, t704)\n", + " # t690 = prims.convert_element_type(t689, dtypes.float32) # t690: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t691 = prims.neg(t690) # t691: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t692 = prims.convert_element_type(t691, dtypes.bfloat16) # t692: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t705 = prims.convert_element_type(t704, dtypes.float32) # t705: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t706 = prims.neg(t705) # t706: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t707 = prims.convert_element_type(t706, dtypes.bfloat16) # t707: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t689, t704\n", + " t708 = torch.cat((t707, t703), -1) # t708: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t708 = ltorch.cat((t707, t703), -1) # t708: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t708 = 
prims.cat((t707, t703), -1) # t708: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t707, t703\n", + " t693 = torch.cat((t692, t688), -1) # t693: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t693 = ltorch.cat((t692, t688), -1) # t693: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t693 = prims.cat((t692, t688), -1) # t693: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t692, t688\n", + " [t701, t716] = nvFusion27(t154, t157, t687, t693, t702, t708)\n", + " # t695 = prims.convert_element_type(t687, dtypes.float32) # t695: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t710 = prims.convert_element_type(t702, dtypes.float32) # t710: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t711 = prims.mul(t710, t154) # t711: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t713 = prims.convert_element_type(t708, dtypes.float32) # t713: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t714 = prims.mul(t713, t157) # t714: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t715 = prims.add(t711, t714) # t715: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t716 = prims.convert_element_type(t715, dtypes.bfloat16) # t716: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t696 = prims.mul(t695, t154) # t696: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t698 = prims.convert_element_type(t693, dtypes.float32) # t698: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t699 = prims.mul(t698, t157) # t699: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t700 = prims.add(t696, t699) # t700: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t701 = prims.convert_element_type(t700, dtypes.bfloat16) # t701: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t687, t693, t702, t708\n", + " t720 = torch.cat((t716, t719), -1) # t720: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t720 = ltorch.cat((t716, t719), -1) # t720: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t720 = prims.cat((t716, t719), -1) # t720: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t716, t719\n", + " t718 = torch.cat((t701, t717), -1) # t718: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t718 = ltorch.cat((t701, t717), -1) # t718: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t718 = prims.cat((t701, t717), -1) # t718: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t701, t717\n", + " (t721, t722, t723, t724, _, _, t725, t726, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t718, t720, t686, 0.0, True, scale=0.08838834764831843)\n", + " t728 = torch.permute(t721, (0, 2, 1, 3)) # t728: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t728 = ltorch.permute(t721, (0, 2, 1, 3)) # t728: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t728 = prims.transpose(t721, (0, 2, 1, 3)) # t728: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t729 = torch.reshape(t728, (1, 512, 4096)) # t729: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t729 = ltorch.reshape(t728, (1, 512, 4096)) # t729: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t729 = prims.reshape(t728, (1, 512, 4096)) # t729: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t728\n", + " t730 = torch.nn.functional.linear(t729, t95, None) # t730: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t730 = ltorch.linear(t729, t95, None) # t730: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t730 = prims.linear(t729, t95, None) # t730: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t734, t741, t749] = nvFusion28(t662, t730, t745)\n", + " # t732 = prims.convert_element_type(t662, dtypes.float32) # t732: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t731 = prims.convert_element_type(t730, dtypes.float32) # t731: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t733 = prims.add(t731, t732) # t733: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t734 = prims.convert_element_type(t733, 
dtypes.bfloat16) # t734: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t736 = prims.mul(t733, t733) # t736: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t737 = prims.sum(t736, (2,)) # t737: \"cuda:0 f32[1, 512]\"\n", + " # t738 = prims.broadcast_in_dim(t737, [1, 512, 1], [0, 1]) # t738: \"cuda:0 f32[1, 512, 1]\"\n", + " # t739 = prims.div(t738, 4096.0) # t739: \"cuda:0 f32[1, 512, 1]\"\n", + " # t740 = prims.add(t739, 1e-05) # t740: \"cuda:0 f32[1, 512, 1]\"\n", + " # t741 = prims.rsqrt(t740) # t741: \"cuda:0 f32[1, 512, 1]\"\n", + " # t742 = prims.broadcast_in_dim(t741, (1, 512, 4096), (0, 1, 2)) # t742: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t743 = prims.mul(t733, t742) # t743: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t747 = prims.convert_element_type(t745, dtypes.float32) # t747: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t748 = prims.mul(t743, t747) # t748: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t749 = prims.convert_element_type(t748, dtypes.bfloat16) # t749: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t750 = torch.nn.functional.linear(t749, t24, None) # t750: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t750 = ltorch.linear(t749, t24, None) # t750: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t750 = prims.linear(t749, t24, None) # t750: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t751 = torch.nn.functional.linear(t749, t40, None) # t751: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t751 = ltorch.linear(t749, t40, None) # t751: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t751 = prims.linear(t749, t40, None) # t751: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t765] = nvFusion29(t750, t751)\n", + " # t752 = prims.convert_element_type(t750, dtypes.float32) # t752: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t753 = prims.neg(t752) # t753: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t754 = prims.exp(t753) # t754: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t755 = prims.add(1.0, t754) # t755: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t756 = prims.reciprocal(t755) # t756: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t760 = prims.mul(t752, t756) # t760: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t763 = prims.convert_element_type(t751, dtypes.float32) # t763: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t764 = prims.mul(t760, t763) # t764: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t765 = prims.convert_element_type(t764, dtypes.bfloat16) # t765: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t766 = torch.nn.functional.linear(t765, t96, None) # t766: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t766 = ltorch.linear(t765, t96, None) # t766: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t766 = prims.linear(t765, t96, None) # t766: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t770, t777, t785] = nvFusion30(t734, t766, t781)\n", + " # t768 = prims.convert_element_type(t734, dtypes.float32) # t768: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t767 = prims.convert_element_type(t766, dtypes.float32) # t767: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t769 = prims.add(t767, t768) # t769: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t770 = prims.convert_element_type(t769, dtypes.bfloat16) # t770: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t772 = prims.mul(t769, t769) # t772: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t773 = prims.sum(t772, (2,)) # t773: \"cuda:0 f32[1, 512]\"\n", + " # t774 = prims.broadcast_in_dim(t773, [1, 512, 1], [0, 1]) # t774: \"cuda:0 f32[1, 512, 1]\"\n", + " # t775 = prims.div(t774, 4096.0) # t775: \"cuda:0 f32[1, 512, 1]\"\n", + " # t776 = prims.add(t775, 1e-05) # t776: \"cuda:0 f32[1, 512, 1]\"\n", + " # t777 = prims.rsqrt(t776) # t777: \"cuda:0 f32[1, 512, 1]\"\n", + " # t778 = 
prims.broadcast_in_dim(t777, (1, 512, 4096), (0, 1, 2)) # t778: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t779 = prims.mul(t769, t778) # t779: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t783 = prims.convert_element_type(t781, dtypes.float32) # t783: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t784 = prims.mul(t779, t783) # t784: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t785 = prims.convert_element_type(t784, dtypes.bfloat16) # t785: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t786 = torch.nn.functional.linear(t785, t9, None) # t786: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t786 = ltorch.linear(t785, t9, None) # t786: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t786 = prims.linear(t785, t9, None) # t786: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t787 = torch.reshape(t786, (1, 512, 32, 3, 128)) # t787: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t787 = ltorch.reshape(t786, (1, 512, 32, 3, 128)) # t787: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t787 = prims.reshape(t786, (1, 512, 32, 3, 128)) # t787: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t786\n", + " t788 = torch.permute(t787, (0, 2, 3, 1, 4)) # t788: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t788 = ltorch.permute(t787, (0, 2, 3, 1, 4)) # t788: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t788 = prims.transpose(t787, (0, 2, 3, 1, 4)) # t788: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t787\n", + " (t789, t790, t791) = torch.split(t788, (1, 1, 1), 2)\n", + " # (t789, t790, t791) = ltorch.split(t788, (1, 1, 1), 2)\n", + " # t789 = prims.slice_prim(t788, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t789: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t790 = prims.slice_prim(t788, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t790: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t791 = prims.slice_prim(t788, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t791: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t788\n", + " t792 = torch.reshape(t789, (1, 32, 512, 128)) # t792: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t792 = ltorch.reshape(t789, (1, 32, 512, 128)) # t792: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t792 = prims.reshape(t789, (1, 32, 512, 128)) # t792: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t789\n", + " t793 = torch.reshape(t790, (1, 32, 512, 128)) # t793: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t793 = ltorch.reshape(t790, (1, 32, 512, 128)) # t793: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t793 = prims.reshape(t790, (1, 32, 512, 128)) # t793: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t790\n", + " t794 = torch.reshape(t791, (1, 32, 512, 128)) # t794: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t794 = ltorch.reshape(t791, (1, 32, 512, 128)) # t794: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t794 = prims.reshape(t791, (1, 32, 512, 128)) # t794: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t791\n", + " t795 = torch_slice_prim_impl(t792, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t795: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t810 = torch_slice_prim_impl(t793, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t810: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t825 = torch_slice_prim_impl(t792, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t825: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t792\n", + " t827 = torch_slice_prim_impl(t793, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t827: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t793\n", + " t796 = torch_slice_prim_impl(t795, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t796: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t797 = 
torch_slice_prim_impl(t795, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t797: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t811 = torch_slice_prim_impl(t810, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t811: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t812 = torch_slice_prim_impl(t810, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t812: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t800, t815] = nvFusion31(t795, t797, t810, t812)\n", + " # t798 = prims.convert_element_type(t797, dtypes.float32) # t798: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t799 = prims.neg(t798) # t799: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t800 = prims.convert_element_type(t799, dtypes.bfloat16) # t800: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t813 = prims.convert_element_type(t812, dtypes.float32) # t813: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t814 = prims.neg(t813) # t814: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t815 = prims.convert_element_type(t814, dtypes.bfloat16) # t815: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t797, t812\n", + " t816 = torch.cat((t815, t811), -1) # t816: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t816 = ltorch.cat((t815, t811), -1) # t816: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t816 = prims.cat((t815, t811), -1) # t816: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t815, t811\n", + " t801 = torch.cat((t800, t796), -1) # t801: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t801 = ltorch.cat((t800, t796), -1) # t801: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t801 = prims.cat((t800, t796), -1) # t801: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t800, t796\n", + " [t809, t824] = nvFusion32(t154, t157, t795, t801, t810, t816)\n", + " # t803 = prims.convert_element_type(t795, dtypes.float32) # t803: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t818 = prims.convert_element_type(t810, dtypes.float32) # t818: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t819 = prims.mul(t818, t154) # t819: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t821 = prims.convert_element_type(t816, dtypes.float32) # t821: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t822 = prims.mul(t821, t157) # t822: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t823 = prims.add(t819, t822) # t823: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t824 = prims.convert_element_type(t823, dtypes.bfloat16) # t824: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t804 = prims.mul(t803, t154) # t804: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t806 = prims.convert_element_type(t801, dtypes.float32) # t806: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t807 = prims.mul(t806, t157) # t807: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t808 = prims.add(t804, t807) # t808: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t809 = prims.convert_element_type(t808, dtypes.bfloat16) # t809: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t795, t801, t810, t816\n", + " t828 = torch.cat((t824, t827), -1) # t828: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t828 = ltorch.cat((t824, t827), -1) # t828: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t828 = prims.cat((t824, t827), -1) # t828: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t824, t827\n", + " t826 = torch.cat((t809, t825), -1) # t826: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t826 = ltorch.cat((t809, t825), -1) # t826: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t826 = prims.cat((t809, t825), -1) # t826: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t809, t825\n", + " (t829, t830, t831, t832, _, _, t833, t834, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t826, t828, t794, 0.0, True, scale=0.08838834764831843)\n", + " t836 = 
torch.permute(t829, (0, 2, 1, 3)) # t836: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t836 = ltorch.permute(t829, (0, 2, 1, 3)) # t836: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t836 = prims.transpose(t829, (0, 2, 1, 3)) # t836: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t837 = torch.reshape(t836, (1, 512, 4096)) # t837: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t837 = ltorch.reshape(t836, (1, 512, 4096)) # t837: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t837 = prims.reshape(t836, (1, 512, 4096)) # t837: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t836\n", + " t838 = torch.nn.functional.linear(t837, t97, None) # t838: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t838 = ltorch.linear(t837, t97, None) # t838: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t838 = prims.linear(t837, t97, None) # t838: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t842, t849, t857] = nvFusion33(t770, t838, t853)\n", + " # t840 = prims.convert_element_type(t770, dtypes.float32) # t840: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t839 = prims.convert_element_type(t838, dtypes.float32) # t839: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t841 = prims.add(t839, t840) # t841: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t842 = prims.convert_element_type(t841, dtypes.bfloat16) # t842: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t844 = prims.mul(t841, t841) # t844: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t845 = prims.sum(t844, (2,)) # t845: \"cuda:0 f32[1, 512]\"\n", + " # t846 = prims.broadcast_in_dim(t845, [1, 512, 1], [0, 1]) # t846: \"cuda:0 f32[1, 512, 1]\"\n", + " # t847 = prims.div(t846, 4096.0) # t847: \"cuda:0 f32[1, 512, 1]\"\n", + " # t848 = prims.add(t847, 1e-05) # t848: \"cuda:0 f32[1, 512, 1]\"\n", + " # t849 = prims.rsqrt(t848) # t849: \"cuda:0 f32[1, 512, 1]\"\n", + " # t850 = prims.broadcast_in_dim(t849, (1, 512, 4096), (0, 1, 2)) # t850: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t851 = prims.mul(t841, t850) # t851: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t855 = prims.convert_element_type(t853, dtypes.float32) # t855: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t856 = prims.mul(t851, t855) # t856: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t857 = prims.convert_element_type(t856, dtypes.bfloat16) # t857: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t858 = torch.nn.functional.linear(t857, t25, None) # t858: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t858 = ltorch.linear(t857, t25, None) # t858: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t858 = prims.linear(t857, t25, None) # t858: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t859 = torch.nn.functional.linear(t857, t41, None) # t859: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t859 = ltorch.linear(t857, t41, None) # t859: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t859 = prims.linear(t857, t41, None) # t859: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t873] = nvFusion34(t858, t859)\n", + " # t860 = prims.convert_element_type(t858, dtypes.float32) # t860: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t861 = prims.neg(t860) # t861: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t862 = prims.exp(t861) # t862: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t863 = prims.add(1.0, t862) # t863: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t864 = prims.reciprocal(t863) # t864: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t868 = prims.mul(t860, t864) # t868: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t871 = prims.convert_element_type(t859, dtypes.float32) # t871: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t872 = prims.mul(t868, t871) # t872: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t873 = prims.convert_element_type(t872, dtypes.bfloat16) # t873: \"cuda:0 bf16[1, 512, 11008]\"\n", + " 
t874 = torch.nn.functional.linear(t873, t98, None) # t874: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t874 = ltorch.linear(t873, t98, None) # t874: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t874 = prims.linear(t873, t98, None) # t874: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t878, t885, t893] = nvFusion35(t842, t874, t889)\n", + " # t876 = prims.convert_element_type(t842, dtypes.float32) # t876: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t875 = prims.convert_element_type(t874, dtypes.float32) # t875: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t877 = prims.add(t875, t876) # t877: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t878 = prims.convert_element_type(t877, dtypes.bfloat16) # t878: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t880 = prims.mul(t877, t877) # t880: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t881 = prims.sum(t880, (2,)) # t881: \"cuda:0 f32[1, 512]\"\n", + " # t882 = prims.broadcast_in_dim(t881, [1, 512, 1], [0, 1]) # t882: \"cuda:0 f32[1, 512, 1]\"\n", + " # t883 = prims.div(t882, 4096.0) # t883: \"cuda:0 f32[1, 512, 1]\"\n", + " # t884 = prims.add(t883, 1e-05) # t884: \"cuda:0 f32[1, 512, 1]\"\n", + " # t885 = prims.rsqrt(t884) # t885: \"cuda:0 f32[1, 512, 1]\"\n", + " # t886 = prims.broadcast_in_dim(t885, (1, 512, 4096), (0, 1, 2)) # t886: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t887 = prims.mul(t877, t886) # t887: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t891 = prims.convert_element_type(t889, dtypes.float32) # t891: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t892 = prims.mul(t887, t891) # t892: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t893 = prims.convert_element_type(t892, dtypes.bfloat16) # t893: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t894 = torch.nn.functional.linear(t893, t10, None) # t894: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t894 = ltorch.linear(t893, t10, None) # t894: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t894 = prims.linear(t893, t10, None) # t894: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t895 = torch.reshape(t894, (1, 512, 32, 3, 128)) # t895: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t895 = ltorch.reshape(t894, (1, 512, 32, 3, 128)) # t895: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t895 = prims.reshape(t894, (1, 512, 32, 3, 128)) # t895: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t894\n", + " t896 = torch.permute(t895, (0, 2, 3, 1, 4)) # t896: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t896 = ltorch.permute(t895, (0, 2, 3, 1, 4)) # t896: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t896 = prims.transpose(t895, (0, 2, 3, 1, 4)) # t896: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t895\n", + " (t897, t898, t899) = torch.split(t896, (1, 1, 1), 2)\n", + " # (t897, t898, t899) = ltorch.split(t896, (1, 1, 1), 2)\n", + " # t897 = prims.slice_prim(t896, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t897: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t898 = prims.slice_prim(t896, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t898: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t899 = prims.slice_prim(t896, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t899: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t896\n", + " t900 = torch.reshape(t897, (1, 32, 512, 128)) # t900: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t900 = ltorch.reshape(t897, (1, 32, 512, 128)) # t900: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t900 = prims.reshape(t897, (1, 32, 512, 128)) # t900: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t897\n", + " t901 = torch.reshape(t898, (1, 32, 512, 128)) # t901: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t901 = ltorch.reshape(t898, (1, 
32, 512, 128)) # t901: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t901 = prims.reshape(t898, (1, 32, 512, 128)) # t901: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t898\n", + " t902 = torch.reshape(t899, (1, 32, 512, 128)) # t902: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t902 = ltorch.reshape(t899, (1, 32, 512, 128)) # t902: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t902 = prims.reshape(t899, (1, 32, 512, 128)) # t902: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t899\n", + " t935 = torch_slice_prim_impl(t901, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t935: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t903 = torch_slice_prim_impl(t900, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t903: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t918 = torch_slice_prim_impl(t901, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t918: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t901\n", + " t933 = torch_slice_prim_impl(t900, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t933: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t900\n", + " t904 = torch_slice_prim_impl(t903, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t904: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t905 = torch_slice_prim_impl(t903, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t905: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t919 = torch_slice_prim_impl(t918, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t919: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t920 = torch_slice_prim_impl(t918, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t920: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t908, t923] = nvFusion36(t903, t905, t918, t920)\n", + " # t906 = prims.convert_element_type(t905, dtypes.float32) # t906: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t907 = prims.neg(t906) # t907: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t908 = prims.convert_element_type(t907, dtypes.bfloat16) # t908: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t921 = prims.convert_element_type(t920, dtypes.float32) # t921: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t922 = prims.neg(t921) # t922: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t923 = prims.convert_element_type(t922, dtypes.bfloat16) # t923: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t905, t920\n", + " t924 = torch.cat((t923, t919), -1) # t924: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t924 = ltorch.cat((t923, t919), -1) # t924: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t924 = prims.cat((t923, t919), -1) # t924: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t923, t919\n", + " t909 = torch.cat((t908, t904), -1) # t909: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t909 = ltorch.cat((t908, t904), -1) # t909: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t909 = prims.cat((t908, t904), -1) # t909: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t908, t904\n", + " [t917, t932] = nvFusion37(t154, t157, t903, t909, t918, t924)\n", + " # t911 = prims.convert_element_type(t903, dtypes.float32) # t911: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t926 = prims.convert_element_type(t918, dtypes.float32) # t926: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t927 = prims.mul(t926, t154) # t927: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t929 = prims.convert_element_type(t924, dtypes.float32) # t929: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t930 = prims.mul(t929, t157) # t930: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t931 = prims.add(t927, t930) # t931: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t932 = prims.convert_element_type(t931, dtypes.bfloat16) # t932: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t912 = prims.mul(t911, t154) # t912: \"cuda:0 f32[1, 
32, 512, 128]\"\n", + " # t914 = prims.convert_element_type(t909, dtypes.float32) # t914: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t915 = prims.mul(t914, t157) # t915: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t916 = prims.add(t912, t915) # t916: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t917 = prims.convert_element_type(t916, dtypes.bfloat16) # t917: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t903, t909, t918, t924\n", + " t936 = torch.cat((t932, t935), -1) # t936: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t936 = ltorch.cat((t932, t935), -1) # t936: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t936 = prims.cat((t932, t935), -1) # t936: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t932, t935\n", + " t934 = torch.cat((t917, t933), -1) # t934: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t934 = ltorch.cat((t917, t933), -1) # t934: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t934 = prims.cat((t917, t933), -1) # t934: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t917, t933\n", + " (t937, t938, t939, t940, _, _, t941, t942, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t934, t936, t902, 0.0, True, scale=0.08838834764831843)\n", + " t944 = torch.permute(t937, (0, 2, 1, 3)) # t944: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t944 = ltorch.permute(t937, (0, 2, 1, 3)) # t944: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t944 = prims.transpose(t937, (0, 2, 1, 3)) # t944: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t945 = torch.reshape(t944, (1, 512, 4096)) # t945: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t945 = ltorch.reshape(t944, (1, 512, 4096)) # t945: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t945 = prims.reshape(t944, (1, 512, 4096)) # t945: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t944\n", + " t946 = torch.nn.functional.linear(t945, t99, None) # t946: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t946 = ltorch.linear(t945, t99, None) # t946: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t946 = prims.linear(t945, t99, None) # t946: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t950, t957, t965] = nvFusion38(t878, t946, t961)\n", + " # t948 = prims.convert_element_type(t878, dtypes.float32) # t948: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t947 = prims.convert_element_type(t946, dtypes.float32) # t947: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t949 = prims.add(t947, t948) # t949: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t950 = prims.convert_element_type(t949, dtypes.bfloat16) # t950: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t952 = prims.mul(t949, t949) # t952: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t953 = prims.sum(t952, (2,)) # t953: \"cuda:0 f32[1, 512]\"\n", + " # t954 = prims.broadcast_in_dim(t953, [1, 512, 1], [0, 1]) # t954: \"cuda:0 f32[1, 512, 1]\"\n", + " # t955 = prims.div(t954, 4096.0) # t955: \"cuda:0 f32[1, 512, 1]\"\n", + " # t956 = prims.add(t955, 1e-05) # t956: \"cuda:0 f32[1, 512, 1]\"\n", + " # t957 = prims.rsqrt(t956) # t957: \"cuda:0 f32[1, 512, 1]\"\n", + " # t958 = prims.broadcast_in_dim(t957, (1, 512, 4096), (0, 1, 2)) # t958: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t959 = prims.mul(t949, t958) # t959: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t963 = prims.convert_element_type(t961, dtypes.float32) # t963: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t964 = prims.mul(t959, t963) # t964: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t965 = prims.convert_element_type(t964, dtypes.bfloat16) # t965: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t967 = torch.nn.functional.linear(t965, t42, None) # t967: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t967 = ltorch.linear(t965, t42, None) # t967: \"cuda:0 bf16[1, 512, 
11008]\"\n", + " # t967 = prims.linear(t965, t42, None) # t967: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t966 = torch.nn.functional.linear(t965, t26, None) # t966: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t966 = ltorch.linear(t965, t26, None) # t966: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t966 = prims.linear(t965, t26, None) # t966: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t981] = nvFusion39(t966, t967)\n", + " # t968 = prims.convert_element_type(t966, dtypes.float32) # t968: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t969 = prims.neg(t968) # t969: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t970 = prims.exp(t969) # t970: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t971 = prims.add(1.0, t970) # t971: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t972 = prims.reciprocal(t971) # t972: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t976 = prims.mul(t968, t972) # t976: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t979 = prims.convert_element_type(t967, dtypes.float32) # t979: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t980 = prims.mul(t976, t979) # t980: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t981 = prims.convert_element_type(t980, dtypes.bfloat16) # t981: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t982 = torch.nn.functional.linear(t981, t100, None) # t982: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t982 = ltorch.linear(t981, t100, None) # t982: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t982 = prims.linear(t981, t100, None) # t982: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1001, t986, t993] = nvFusion40(t950, t982, t997)\n", + " # t984 = prims.convert_element_type(t950, dtypes.float32) # t984: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t983 = prims.convert_element_type(t982, dtypes.float32) # t983: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t985 = prims.add(t983, t984) # t985: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t986 = prims.convert_element_type(t985, dtypes.bfloat16) # t986: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t988 = prims.mul(t985, t985) # t988: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t989 = prims.sum(t988, (2,)) # t989: \"cuda:0 f32[1, 512]\"\n", + " # t990 = prims.broadcast_in_dim(t989, [1, 512, 1], [0, 1]) # t990: \"cuda:0 f32[1, 512, 1]\"\n", + " # t991 = prims.div(t990, 4096.0) # t991: \"cuda:0 f32[1, 512, 1]\"\n", + " # t992 = prims.add(t991, 1e-05) # t992: \"cuda:0 f32[1, 512, 1]\"\n", + " # t993 = prims.rsqrt(t992) # t993: \"cuda:0 f32[1, 512, 1]\"\n", + " # t994 = prims.broadcast_in_dim(t993, (1, 512, 4096), (0, 1, 2)) # t994: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t995 = prims.mul(t985, t994) # t995: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t999 = prims.convert_element_type(t997, dtypes.float32) # t999: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1000 = prims.mul(t995, t999) # t1000: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1001 = prims.convert_element_type(t1000, dtypes.bfloat16) # t1001: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1002 = torch.nn.functional.linear(t1001, t11, None) # t1002: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1002 = ltorch.linear(t1001, t11, None) # t1002: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1002 = prims.linear(t1001, t11, None) # t1002: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1003 = torch.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1003 = ltorch.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1003 = prims.reshape(t1002, (1, 512, 32, 3, 128)) # t1003: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1002\n", + " t1004 = torch.permute(t1003, (0, 2, 3, 1, 4)) # t1004: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # 
t1004 = ltorch.permute(t1003, (0, 2, 3, 1, 4)) # t1004: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1004 = prims.transpose(t1003, (0, 2, 3, 1, 4)) # t1004: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1003\n", + " (t1005, t1006, t1007) = torch.split(t1004, (1, 1, 1), 2)\n", + " # (t1005, t1006, t1007) = ltorch.split(t1004, (1, 1, 1), 2)\n", + " # t1005 = prims.slice_prim(t1004, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1005: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1006 = prims.slice_prim(t1004, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1006: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1007 = prims.slice_prim(t1004, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1007: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1004\n", + " t1008 = torch.reshape(t1005, (1, 32, 512, 128)) # t1008: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1008 = ltorch.reshape(t1005, (1, 32, 512, 128)) # t1008: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1008 = prims.reshape(t1005, (1, 32, 512, 128)) # t1008: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1005\n", + " t1009 = torch.reshape(t1006, (1, 32, 512, 128)) # t1009: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1009 = ltorch.reshape(t1006, (1, 32, 512, 128)) # t1009: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1009 = prims.reshape(t1006, (1, 32, 512, 128)) # t1009: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1006\n", + " t1010 = torch.reshape(t1007, (1, 32, 512, 128)) # t1010: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1010 = ltorch.reshape(t1007, (1, 32, 512, 128)) # t1010: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1010 = prims.reshape(t1007, (1, 32, 512, 128)) # t1010: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1007\n", + " t1026 = torch_slice_prim_impl(t1009, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1026: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1041 = torch_slice_prim_impl(t1008, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1041: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t1043 = torch_slice_prim_impl(t1009, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1043: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1009\n", + " t1011 = torch_slice_prim_impl(t1008, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1011: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1008\n", + " t1027 = torch_slice_prim_impl(t1026, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1027: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1028 = torch_slice_prim_impl(t1026, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1028: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1013 = torch_slice_prim_impl(t1011, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1013: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1012 = torch_slice_prim_impl(t1011, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1012: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1016, t1031] = nvFusion41(t1011, t1013, t1026, t1028)\n", + " # t1014 = prims.convert_element_type(t1013, dtypes.float32) # t1014: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1015 = prims.neg(t1014) # t1015: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1016 = prims.convert_element_type(t1015, dtypes.bfloat16) # t1016: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1029 = prims.convert_element_type(t1028, dtypes.float32) # t1029: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1030 = prims.neg(t1029) # t1030: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1031 = prims.convert_element_type(t1030, dtypes.bfloat16) # t1031: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1013, t1028\n", + " t1032 = torch.cat((t1031, 
t1027), -1) # t1032: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1032 = ltorch.cat((t1031, t1027), -1) # t1032: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1032 = prims.cat((t1031, t1027), -1) # t1032: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1031, t1027\n", + " t1017 = torch.cat((t1016, t1012), -1) # t1017: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1017 = ltorch.cat((t1016, t1012), -1) # t1017: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1017 = prims.cat((t1016, t1012), -1) # t1017: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1016, t1012\n", + " [t1025, t1040] = nvFusion42(t1011, t1017, t1026, t1032, t154, t157)\n", + " # t1019 = prims.convert_element_type(t1011, dtypes.float32) # t1019: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1034 = prims.convert_element_type(t1026, dtypes.float32) # t1034: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1020 = prims.mul(t1019, t154) # t1020: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1022 = prims.convert_element_type(t1017, dtypes.float32) # t1022: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1023 = prims.mul(t1022, t157) # t1023: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1024 = prims.add(t1020, t1023) # t1024: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1025 = prims.convert_element_type(t1024, dtypes.bfloat16) # t1025: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1035 = prims.mul(t1034, t154) # t1035: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1037 = prims.convert_element_type(t1032, dtypes.float32) # t1037: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1038 = prims.mul(t1037, t157) # t1038: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1039 = prims.add(t1035, t1038) # t1039: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1040 = prims.convert_element_type(t1039, dtypes.bfloat16) # t1040: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1011, t1017, t1026, t1032\n", + " t1042 = torch.cat((t1025, t1041), -1) # t1042: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1042 = ltorch.cat((t1025, t1041), -1) # t1042: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1042 = prims.cat((t1025, t1041), -1) # t1042: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1025, t1041\n", + " t1044 = torch.cat((t1040, t1043), -1) # t1044: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1044 = ltorch.cat((t1040, t1043), -1) # t1044: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1044 = prims.cat((t1040, t1043), -1) # t1044: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1040, t1043\n", + " (t1045, t1046, t1047, t1048, _, _, t1049, t1050, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1042, t1044, t1010, 0.0, True, scale=0.08838834764831843)\n", + " t1052 = torch.permute(t1045, (0, 2, 1, 3)) # t1052: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1052 = ltorch.permute(t1045, (0, 2, 1, 3)) # t1052: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1052 = prims.transpose(t1045, (0, 2, 1, 3)) # t1052: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1053 = torch.reshape(t1052, (1, 512, 4096)) # t1053: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1053 = ltorch.reshape(t1052, (1, 512, 4096)) # t1053: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1053 = prims.reshape(t1052, (1, 512, 4096)) # t1053: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1052\n", + " t1054 = torch.nn.functional.linear(t1053, t101, None) # t1054: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1054 = ltorch.linear(t1053, t101, None) # t1054: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1054 = prims.linear(t1053, t101, None) # t1054: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1058, t1065, t1073] = nvFusion43(t1054, t1069, t986)\n", + " # t1056 = 
prims.convert_element_type(t986, dtypes.float32) # t1056: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1055 = prims.convert_element_type(t1054, dtypes.float32) # t1055: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1057 = prims.add(t1055, t1056) # t1057: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1058 = prims.convert_element_type(t1057, dtypes.bfloat16) # t1058: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1060 = prims.mul(t1057, t1057) # t1060: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1061 = prims.sum(t1060, (2,)) # t1061: \"cuda:0 f32[1, 512]\"\n", + " # t1062 = prims.broadcast_in_dim(t1061, [1, 512, 1], [0, 1]) # t1062: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1063 = prims.div(t1062, 4096.0) # t1063: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1064 = prims.add(t1063, 1e-05) # t1064: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1065 = prims.rsqrt(t1064) # t1065: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1066 = prims.broadcast_in_dim(t1065, (1, 512, 4096), (0, 1, 2)) # t1066: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1067 = prims.mul(t1057, t1066) # t1067: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1071 = prims.convert_element_type(t1069, dtypes.float32) # t1071: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1072 = prims.mul(t1067, t1071) # t1072: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1073 = prims.convert_element_type(t1072, dtypes.bfloat16) # t1073: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1074 = torch.nn.functional.linear(t1073, t27, None) # t1074: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1074 = ltorch.linear(t1073, t27, None) # t1074: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1074 = prims.linear(t1073, t27, None) # t1074: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1075 = torch.nn.functional.linear(t1073, t43, None) # t1075: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1075 = ltorch.linear(t1073, t43, None) # t1075: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1075 = prims.linear(t1073, t43, None) # t1075: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1089] = nvFusion44(t1074, t1075)\n", + " # t1076 = prims.convert_element_type(t1074, dtypes.float32) # t1076: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1077 = prims.neg(t1076) # t1077: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1078 = prims.exp(t1077) # t1078: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1079 = prims.add(1.0, t1078) # t1079: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1080 = prims.reciprocal(t1079) # t1080: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1084 = prims.mul(t1076, t1080) # t1084: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1087 = prims.convert_element_type(t1075, dtypes.float32) # t1087: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1088 = prims.mul(t1084, t1087) # t1088: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1089 = prims.convert_element_type(t1088, dtypes.bfloat16) # t1089: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1090 = torch.nn.functional.linear(t1089, t102, None) # t1090: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1090 = ltorch.linear(t1089, t102, None) # t1090: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1090 = prims.linear(t1089, t102, None) # t1090: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1094, t1101, t1109] = nvFusion45(t1058, t1090, t1105)\n", + " # t1092 = prims.convert_element_type(t1058, dtypes.float32) # t1092: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1091 = prims.convert_element_type(t1090, dtypes.float32) # t1091: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1093 = prims.add(t1091, t1092) # t1093: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1094 = prims.convert_element_type(t1093, dtypes.bfloat16) # t1094: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1096 = prims.mul(t1093, t1093) # t1096: \"cuda:0 
f32[1, 512, 4096]\"\n", + " # t1097 = prims.sum(t1096, (2,)) # t1097: \"cuda:0 f32[1, 512]\"\n", + " # t1098 = prims.broadcast_in_dim(t1097, [1, 512, 1], [0, 1]) # t1098: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1099 = prims.div(t1098, 4096.0) # t1099: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1100 = prims.add(t1099, 1e-05) # t1100: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1101 = prims.rsqrt(t1100) # t1101: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1102 = prims.broadcast_in_dim(t1101, (1, 512, 4096), (0, 1, 2)) # t1102: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1103 = prims.mul(t1093, t1102) # t1103: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1107 = prims.convert_element_type(t1105, dtypes.float32) # t1107: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1108 = prims.mul(t1103, t1107) # t1108: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1109 = prims.convert_element_type(t1108, dtypes.bfloat16) # t1109: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1110 = torch.nn.functional.linear(t1109, t12, None) # t1110: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1110 = ltorch.linear(t1109, t12, None) # t1110: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1110 = prims.linear(t1109, t12, None) # t1110: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1111 = torch.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1111 = ltorch.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1111 = prims.reshape(t1110, (1, 512, 32, 3, 128)) # t1111: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1110\n", + " t1112 = torch.permute(t1111, (0, 2, 3, 1, 4)) # t1112: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1112 = ltorch.permute(t1111, (0, 2, 3, 1, 4)) # t1112: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1112 = prims.transpose(t1111, (0, 2, 3, 1, 4)) # t1112: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1111\n", + " (t1113, t1114, t1115) = torch.split(t1112, (1, 1, 1), 2)\n", + " # (t1113, t1114, t1115) = ltorch.split(t1112, (1, 1, 1), 2)\n", + " # t1113 = prims.slice_prim(t1112, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1113: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1114 = prims.slice_prim(t1112, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1114: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1115 = prims.slice_prim(t1112, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1115: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1112\n", + " t1116 = torch.reshape(t1113, (1, 32, 512, 128)) # t1116: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1116 = ltorch.reshape(t1113, (1, 32, 512, 128)) # t1116: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1116 = prims.reshape(t1113, (1, 32, 512, 128)) # t1116: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1113\n", + " t1117 = torch.reshape(t1114, (1, 32, 512, 128)) # t1117: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1117 = ltorch.reshape(t1114, (1, 32, 512, 128)) # t1117: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1117 = prims.reshape(t1114, (1, 32, 512, 128)) # t1117: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1114\n", + " t1118 = torch.reshape(t1115, (1, 32, 512, 128)) # t1118: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1118 = ltorch.reshape(t1115, (1, 32, 512, 128)) # t1118: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1118 = prims.reshape(t1115, (1, 32, 512, 128)) # t1118: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1115\n", + " t1119 = torch_slice_prim_impl(t1116, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1119: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1134 = 
torch_slice_prim_impl(t1117, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1134: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1149 = torch_slice_prim_impl(t1116, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1149: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1116\n", + " t1151 = torch_slice_prim_impl(t1117, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1151: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1117\n", + " t1120 = torch_slice_prim_impl(t1119, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1120: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1121 = torch_slice_prim_impl(t1119, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1121: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1136 = torch_slice_prim_impl(t1134, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1136: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1135 = torch_slice_prim_impl(t1134, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1135: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1124, t1139] = nvFusion46(t1119, t1121, t1134, t1136)\n", + " # t1122 = prims.convert_element_type(t1121, dtypes.float32) # t1122: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1123 = prims.neg(t1122) # t1123: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1124 = prims.convert_element_type(t1123, dtypes.bfloat16) # t1124: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1137 = prims.convert_element_type(t1136, dtypes.float32) # t1137: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1138 = prims.neg(t1137) # t1138: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1139 = prims.convert_element_type(t1138, dtypes.bfloat16) # t1139: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1121, t1136\n", + " t1125 = torch.cat((t1124, t1120), -1) # t1125: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1125 = ltorch.cat((t1124, t1120), -1) # t1125: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1125 = prims.cat((t1124, t1120), -1) # t1125: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1124, t1120\n", + " t1140 = torch.cat((t1139, t1135), -1) # t1140: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1140 = ltorch.cat((t1139, t1135), -1) # t1140: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1140 = prims.cat((t1139, t1135), -1) # t1140: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1139, t1135\n", + " [t1133, t1148] = nvFusion47(t1119, t1125, t1134, t1140, t154, t157)\n", + " # t1127 = prims.convert_element_type(t1119, dtypes.float32) # t1127: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1142 = prims.convert_element_type(t1134, dtypes.float32) # t1142: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1128 = prims.mul(t1127, t154) # t1128: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1130 = prims.convert_element_type(t1125, dtypes.float32) # t1130: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1131 = prims.mul(t1130, t157) # t1131: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1132 = prims.add(t1128, t1131) # t1132: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1133 = prims.convert_element_type(t1132, dtypes.bfloat16) # t1133: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1143 = prims.mul(t1142, t154) # t1143: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1145 = prims.convert_element_type(t1140, dtypes.float32) # t1145: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1146 = prims.mul(t1145, t157) # t1146: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1147 = prims.add(t1143, t1146) # t1147: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1148 = prims.convert_element_type(t1147, dtypes.bfloat16) # t1148: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1119, t1125, t1134, t1140\n", + " t1152 = torch.cat((t1148, t1151), -1) # t1152: \"cuda:0 bf16[1, 32, 
512, 128]\"\n", + " # t1152 = ltorch.cat((t1148, t1151), -1) # t1152: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1152 = prims.cat((t1148, t1151), -1) # t1152: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1148, t1151\n", + " t1150 = torch.cat((t1133, t1149), -1) # t1150: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1150 = ltorch.cat((t1133, t1149), -1) # t1150: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1150 = prims.cat((t1133, t1149), -1) # t1150: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1133, t1149\n", + " (t1153, t1154, t1155, t1156, _, _, t1157, t1158, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1150, t1152, t1118, 0.0, True, scale=0.08838834764831843)\n", + " t1160 = torch.permute(t1153, (0, 2, 1, 3)) # t1160: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1160 = ltorch.permute(t1153, (0, 2, 1, 3)) # t1160: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1160 = prims.transpose(t1153, (0, 2, 1, 3)) # t1160: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1161 = torch.reshape(t1160, (1, 512, 4096)) # t1161: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1161 = ltorch.reshape(t1160, (1, 512, 4096)) # t1161: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1161 = prims.reshape(t1160, (1, 512, 4096)) # t1161: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1160\n", + " t1162 = torch.nn.functional.linear(t1161, t103, None) # t1162: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1162 = ltorch.linear(t1161, t103, None) # t1162: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1162 = prims.linear(t1161, t103, None) # t1162: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1166, t1173, t1181] = nvFusion48(t1094, t1162, t1177)\n", + " # t1164 = prims.convert_element_type(t1094, dtypes.float32) # t1164: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1163 = prims.convert_element_type(t1162, dtypes.float32) # t1163: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1165 = prims.add(t1163, t1164) # t1165: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1166 = prims.convert_element_type(t1165, dtypes.bfloat16) # t1166: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1168 = prims.mul(t1165, t1165) # t1168: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1169 = prims.sum(t1168, (2,)) # t1169: \"cuda:0 f32[1, 512]\"\n", + " # t1170 = prims.broadcast_in_dim(t1169, [1, 512, 1], [0, 1]) # t1170: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1171 = prims.div(t1170, 4096.0) # t1171: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1172 = prims.add(t1171, 1e-05) # t1172: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1173 = prims.rsqrt(t1172) # t1173: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1174 = prims.broadcast_in_dim(t1173, (1, 512, 4096), (0, 1, 2)) # t1174: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1175 = prims.mul(t1165, t1174) # t1175: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1179 = prims.convert_element_type(t1177, dtypes.float32) # t1179: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1180 = prims.mul(t1175, t1179) # t1180: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1181 = prims.convert_element_type(t1180, dtypes.bfloat16) # t1181: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1182 = torch.nn.functional.linear(t1181, t28, None) # t1182: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1182 = ltorch.linear(t1181, t28, None) # t1182: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1182 = prims.linear(t1181, t28, None) # t1182: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1183 = torch.nn.functional.linear(t1181, t44, None) # t1183: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1183 = ltorch.linear(t1181, t44, None) # t1183: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1183 = prims.linear(t1181, t44, None) # t1183: \"cuda:0 bf16[1, 512, 
11008]\"\n", + " [t1197] = nvFusion49(t1182, t1183)\n", + " # t1184 = prims.convert_element_type(t1182, dtypes.float32) # t1184: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1185 = prims.neg(t1184) # t1185: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1186 = prims.exp(t1185) # t1186: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1187 = prims.add(1.0, t1186) # t1187: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1188 = prims.reciprocal(t1187) # t1188: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1192 = prims.mul(t1184, t1188) # t1192: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1195 = prims.convert_element_type(t1183, dtypes.float32) # t1195: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1196 = prims.mul(t1192, t1195) # t1196: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1197 = prims.convert_element_type(t1196, dtypes.bfloat16) # t1197: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1198 = torch.nn.functional.linear(t1197, t104, None) # t1198: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1198 = ltorch.linear(t1197, t104, None) # t1198: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1198 = prims.linear(t1197, t104, None) # t1198: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1202, t1209, t1217] = nvFusion50(t1166, t1198, t1213)\n", + " # t1200 = prims.convert_element_type(t1166, dtypes.float32) # t1200: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1199 = prims.convert_element_type(t1198, dtypes.float32) # t1199: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1201 = prims.add(t1199, t1200) # t1201: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1202 = prims.convert_element_type(t1201, dtypes.bfloat16) # t1202: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1204 = prims.mul(t1201, t1201) # t1204: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1205 = prims.sum(t1204, (2,)) # t1205: \"cuda:0 f32[1, 512]\"\n", + " # t1206 = prims.broadcast_in_dim(t1205, [1, 512, 1], [0, 1]) # t1206: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1207 = prims.div(t1206, 4096.0) # t1207: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1208 = prims.add(t1207, 1e-05) # t1208: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1209 = prims.rsqrt(t1208) # t1209: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1210 = prims.broadcast_in_dim(t1209, (1, 512, 4096), (0, 1, 2)) # t1210: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1211 = prims.mul(t1201, t1210) # t1211: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1215 = prims.convert_element_type(t1213, dtypes.float32) # t1215: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1216 = prims.mul(t1211, t1215) # t1216: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1217 = prims.convert_element_type(t1216, dtypes.bfloat16) # t1217: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1218 = torch.nn.functional.linear(t1217, t13, None) # t1218: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1218 = ltorch.linear(t1217, t13, None) # t1218: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1218 = prims.linear(t1217, t13, None) # t1218: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1219 = torch.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1219 = ltorch.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1219 = prims.reshape(t1218, (1, 512, 32, 3, 128)) # t1219: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1218\n", + " t1220 = torch.permute(t1219, (0, 2, 3, 1, 4)) # t1220: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1220 = ltorch.permute(t1219, (0, 2, 3, 1, 4)) # t1220: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1220 = prims.transpose(t1219, (0, 2, 3, 1, 4)) # t1220: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1219\n", + " (t1221, t1222, t1223) = torch.split(t1220, (1, 1, 
1), 2)\n", + " # (t1221, t1222, t1223) = ltorch.split(t1220, (1, 1, 1), 2)\n", + " # t1221 = prims.slice_prim(t1220, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1221: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1222 = prims.slice_prim(t1220, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1222: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1223 = prims.slice_prim(t1220, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1223: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1220\n", + " t1224 = torch.reshape(t1221, (1, 32, 512, 128)) # t1224: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1224 = ltorch.reshape(t1221, (1, 32, 512, 128)) # t1224: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1224 = prims.reshape(t1221, (1, 32, 512, 128)) # t1224: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1221\n", + " t1225 = torch.reshape(t1222, (1, 32, 512, 128)) # t1225: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1225 = ltorch.reshape(t1222, (1, 32, 512, 128)) # t1225: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1225 = prims.reshape(t1222, (1, 32, 512, 128)) # t1225: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1222\n", + " t1226 = torch.reshape(t1223, (1, 32, 512, 128)) # t1226: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1226 = ltorch.reshape(t1223, (1, 32, 512, 128)) # t1226: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1226 = prims.reshape(t1223, (1, 32, 512, 128)) # t1226: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1223\n", + " t1227 = torch_slice_prim_impl(t1224, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1227: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1242 = torch_slice_prim_impl(t1225, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1242: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1257 = torch_slice_prim_impl(t1224, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1257: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1224\n", + " t1259 = torch_slice_prim_impl(t1225, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1259: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1225\n", + " t1228 = torch_slice_prim_impl(t1227, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1228: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1229 = torch_slice_prim_impl(t1227, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1229: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1243 = torch_slice_prim_impl(t1242, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1243: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1244 = torch_slice_prim_impl(t1242, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1244: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1232, t1247] = nvFusion51(t1227, t1229, t1242, t1244)\n", + " # t1230 = prims.convert_element_type(t1229, dtypes.float32) # t1230: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1231 = prims.neg(t1230) # t1231: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1232 = prims.convert_element_type(t1231, dtypes.bfloat16) # t1232: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1245 = prims.convert_element_type(t1244, dtypes.float32) # t1245: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1246 = prims.neg(t1245) # t1246: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1247 = prims.convert_element_type(t1246, dtypes.bfloat16) # t1247: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1229, t1244\n", + " t1233 = torch.cat((t1232, t1228), -1) # t1233: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1233 = ltorch.cat((t1232, t1228), -1) # t1233: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1233 = prims.cat((t1232, t1228), -1) # t1233: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1232, t1228\n", + " t1248 = 
torch.cat((t1247, t1243), -1) # t1248: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1248 = ltorch.cat((t1247, t1243), -1) # t1248: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1248 = prims.cat((t1247, t1243), -1) # t1248: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1247, t1243\n", + " [t1241, t1256] = nvFusion52(t1227, t1233, t1242, t1248, t154, t157)\n", + " # t1235 = prims.convert_element_type(t1227, dtypes.float32) # t1235: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1250 = prims.convert_element_type(t1242, dtypes.float32) # t1250: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1236 = prims.mul(t1235, t154) # t1236: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1238 = prims.convert_element_type(t1233, dtypes.float32) # t1238: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1239 = prims.mul(t1238, t157) # t1239: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1240 = prims.add(t1236, t1239) # t1240: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1241 = prims.convert_element_type(t1240, dtypes.bfloat16) # t1241: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1251 = prims.mul(t1250, t154) # t1251: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1253 = prims.convert_element_type(t1248, dtypes.float32) # t1253: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1254 = prims.mul(t1253, t157) # t1254: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1255 = prims.add(t1251, t1254) # t1255: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1256 = prims.convert_element_type(t1255, dtypes.bfloat16) # t1256: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1227, t1233, t1242, t1248\n", + " t1258 = torch.cat((t1241, t1257), -1) # t1258: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1258 = ltorch.cat((t1241, t1257), -1) # t1258: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1258 = prims.cat((t1241, t1257), -1) # t1258: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1241, t1257\n", + " t1260 = torch.cat((t1256, t1259), -1) # t1260: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1260 = ltorch.cat((t1256, t1259), -1) # t1260: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1260 = prims.cat((t1256, t1259), -1) # t1260: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1256, t1259\n", + " (t1261, t1262, t1263, t1264, _, _, t1265, t1266, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1258, t1260, t1226, 0.0, True, scale=0.08838834764831843)\n", + " t1268 = torch.permute(t1261, (0, 2, 1, 3)) # t1268: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1268 = ltorch.permute(t1261, (0, 2, 1, 3)) # t1268: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1268 = prims.transpose(t1261, (0, 2, 1, 3)) # t1268: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1269 = torch.reshape(t1268, (1, 512, 4096)) # t1269: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1269 = ltorch.reshape(t1268, (1, 512, 4096)) # t1269: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1269 = prims.reshape(t1268, (1, 512, 4096)) # t1269: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1268\n", + " t1270 = torch.nn.functional.linear(t1269, t105, None) # t1270: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1270 = ltorch.linear(t1269, t105, None) # t1270: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1270 = prims.linear(t1269, t105, None) # t1270: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1274, t1281, t1289] = nvFusion53(t1202, t1270, t1285)\n", + " # t1272 = prims.convert_element_type(t1202, dtypes.float32) # t1272: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1271 = prims.convert_element_type(t1270, dtypes.float32) # t1271: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1273 = prims.add(t1271, t1272) # t1273: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1274 = 
prims.convert_element_type(t1273, dtypes.bfloat16) # t1274: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1276 = prims.mul(t1273, t1273) # t1276: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1277 = prims.sum(t1276, (2,)) # t1277: \"cuda:0 f32[1, 512]\"\n", + " # t1278 = prims.broadcast_in_dim(t1277, [1, 512, 1], [0, 1]) # t1278: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1279 = prims.div(t1278, 4096.0) # t1279: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1280 = prims.add(t1279, 1e-05) # t1280: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1281 = prims.rsqrt(t1280) # t1281: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1282 = prims.broadcast_in_dim(t1281, (1, 512, 4096), (0, 1, 2)) # t1282: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1283 = prims.mul(t1273, t1282) # t1283: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1287 = prims.convert_element_type(t1285, dtypes.float32) # t1287: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1288 = prims.mul(t1283, t1287) # t1288: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1289 = prims.convert_element_type(t1288, dtypes.bfloat16) # t1289: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1290 = torch.nn.functional.linear(t1289, t29, None) # t1290: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1290 = ltorch.linear(t1289, t29, None) # t1290: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1290 = prims.linear(t1289, t29, None) # t1290: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1291 = torch.nn.functional.linear(t1289, t45, None) # t1291: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1291 = ltorch.linear(t1289, t45, None) # t1291: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1291 = prims.linear(t1289, t45, None) # t1291: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1305] = nvFusion54(t1290, t1291)\n", + " # t1292 = prims.convert_element_type(t1290, dtypes.float32) # t1292: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1293 = prims.neg(t1292) # t1293: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1294 = prims.exp(t1293) # t1294: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1295 = prims.add(1.0, t1294) # t1295: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1296 = prims.reciprocal(t1295) # t1296: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1300 = prims.mul(t1292, t1296) # t1300: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1303 = prims.convert_element_type(t1291, dtypes.float32) # t1303: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1304 = prims.mul(t1300, t1303) # t1304: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1305 = prims.convert_element_type(t1304, dtypes.bfloat16) # t1305: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1306 = torch.nn.functional.linear(t1305, t106, None) # t1306: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1306 = ltorch.linear(t1305, t106, None) # t1306: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1306 = prims.linear(t1305, t106, None) # t1306: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1310, t1317, t1325] = nvFusion55(t1274, t1306, t1321)\n", + " # t1308 = prims.convert_element_type(t1274, dtypes.float32) # t1308: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1307 = prims.convert_element_type(t1306, dtypes.float32) # t1307: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1309 = prims.add(t1307, t1308) # t1309: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1310 = prims.convert_element_type(t1309, dtypes.bfloat16) # t1310: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1312 = prims.mul(t1309, t1309) # t1312: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1313 = prims.sum(t1312, (2,)) # t1313: \"cuda:0 f32[1, 512]\"\n", + " # t1314 = prims.broadcast_in_dim(t1313, [1, 512, 1], [0, 1]) # t1314: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1315 = prims.div(t1314, 4096.0) # t1315: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1316 = 
prims.add(t1315, 1e-05) # t1316: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1317 = prims.rsqrt(t1316) # t1317: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1318 = prims.broadcast_in_dim(t1317, (1, 512, 4096), (0, 1, 2)) # t1318: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1319 = prims.mul(t1309, t1318) # t1319: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1323 = prims.convert_element_type(t1321, dtypes.float32) # t1323: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1324 = prims.mul(t1319, t1323) # t1324: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1325 = prims.convert_element_type(t1324, dtypes.bfloat16) # t1325: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1326 = torch.nn.functional.linear(t1325, t14, None) # t1326: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1326 = ltorch.linear(t1325, t14, None) # t1326: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1326 = prims.linear(t1325, t14, None) # t1326: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1327 = torch.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1327 = ltorch.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1327 = prims.reshape(t1326, (1, 512, 32, 3, 128)) # t1327: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1326\n", + " t1328 = torch.permute(t1327, (0, 2, 3, 1, 4)) # t1328: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1328 = ltorch.permute(t1327, (0, 2, 3, 1, 4)) # t1328: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1328 = prims.transpose(t1327, (0, 2, 3, 1, 4)) # t1328: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1327\n", + " (t1329, t1330, t1331) = torch.split(t1328, (1, 1, 1), 2)\n", + " # (t1329, t1330, t1331) = ltorch.split(t1328, (1, 1, 1), 2)\n", + " # t1329 = prims.slice_prim(t1328, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1329: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1330 = prims.slice_prim(t1328, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1330: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1331 = prims.slice_prim(t1328, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1331: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1328\n", + " t1332 = torch.reshape(t1329, (1, 32, 512, 128)) # t1332: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1332 = ltorch.reshape(t1329, (1, 32, 512, 128)) # t1332: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1332 = prims.reshape(t1329, (1, 32, 512, 128)) # t1332: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1329\n", + " t1333 = torch.reshape(t1330, (1, 32, 512, 128)) # t1333: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1333 = ltorch.reshape(t1330, (1, 32, 512, 128)) # t1333: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1333 = prims.reshape(t1330, (1, 32, 512, 128)) # t1333: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1330\n", + " t1334 = torch.reshape(t1331, (1, 32, 512, 128)) # t1334: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1334 = ltorch.reshape(t1331, (1, 32, 512, 128)) # t1334: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1334 = prims.reshape(t1331, (1, 32, 512, 128)) # t1334: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1331\n", + " t1335 = torch_slice_prim_impl(t1332, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1335: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1350 = torch_slice_prim_impl(t1333, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1350: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1365 = torch_slice_prim_impl(t1332, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1365: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1332\n", + " t1367 = torch_slice_prim_impl(t1333, [0, 0, 
0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1367: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1333\n", + " t1336 = torch_slice_prim_impl(t1335, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1336: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1337 = torch_slice_prim_impl(t1335, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1337: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1351 = torch_slice_prim_impl(t1350, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1351: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1352 = torch_slice_prim_impl(t1350, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1352: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1340, t1355] = nvFusion56(t1335, t1337, t1350, t1352)\n", + " # t1338 = prims.convert_element_type(t1337, dtypes.float32) # t1338: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1339 = prims.neg(t1338) # t1339: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1340 = prims.convert_element_type(t1339, dtypes.bfloat16) # t1340: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1353 = prims.convert_element_type(t1352, dtypes.float32) # t1353: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1354 = prims.neg(t1353) # t1354: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1355 = prims.convert_element_type(t1354, dtypes.bfloat16) # t1355: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1337, t1352\n", + " t1341 = torch.cat((t1340, t1336), -1) # t1341: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1341 = ltorch.cat((t1340, t1336), -1) # t1341: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1341 = prims.cat((t1340, t1336), -1) # t1341: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1340, t1336\n", + " t1356 = torch.cat((t1355, t1351), -1) # t1356: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1356 = ltorch.cat((t1355, t1351), -1) # t1356: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1356 = prims.cat((t1355, t1351), -1) # t1356: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1355, t1351\n", + " [t1349, t1364] = nvFusion57(t1335, t1341, t1350, t1356, t154, t157)\n", + " # t1343 = prims.convert_element_type(t1335, dtypes.float32) # t1343: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1358 = prims.convert_element_type(t1350, dtypes.float32) # t1358: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1344 = prims.mul(t1343, t154) # t1344: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1346 = prims.convert_element_type(t1341, dtypes.float32) # t1346: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1347 = prims.mul(t1346, t157) # t1347: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1348 = prims.add(t1344, t1347) # t1348: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1349 = prims.convert_element_type(t1348, dtypes.bfloat16) # t1349: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1359 = prims.mul(t1358, t154) # t1359: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1361 = prims.convert_element_type(t1356, dtypes.float32) # t1361: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1362 = prims.mul(t1361, t157) # t1362: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1363 = prims.add(t1359, t1362) # t1363: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1364 = prims.convert_element_type(t1363, dtypes.bfloat16) # t1364: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1335, t1341, t1350, t1356\n", + " t1366 = torch.cat((t1349, t1365), -1) # t1366: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1366 = ltorch.cat((t1349, t1365), -1) # t1366: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1366 = prims.cat((t1349, t1365), -1) # t1366: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1349, t1365\n", + " t1368 = torch.cat((t1364, t1367), -1) # t1368: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # 
t1368 = ltorch.cat((t1364, t1367), -1) # t1368: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1368 = prims.cat((t1364, t1367), -1) # t1368: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1364, t1367\n", + " (t1369, t1370, t1371, t1372, _, _, t1373, t1374, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1366, t1368, t1334, 0.0, True, scale=0.08838834764831843)\n", + " t1376 = torch.permute(t1369, (0, 2, 1, 3)) # t1376: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1376 = ltorch.permute(t1369, (0, 2, 1, 3)) # t1376: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1376 = prims.transpose(t1369, (0, 2, 1, 3)) # t1376: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1377 = torch.reshape(t1376, (1, 512, 4096)) # t1377: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1377 = ltorch.reshape(t1376, (1, 512, 4096)) # t1377: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1377 = prims.reshape(t1376, (1, 512, 4096)) # t1377: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1376\n", + " t1378 = torch.nn.functional.linear(t1377, t107, None) # t1378: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1378 = ltorch.linear(t1377, t107, None) # t1378: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1378 = prims.linear(t1377, t107, None) # t1378: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1382, t1389, t1397] = nvFusion58(t1310, t1378, t1393)\n", + " # t1380 = prims.convert_element_type(t1310, dtypes.float32) # t1380: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1379 = prims.convert_element_type(t1378, dtypes.float32) # t1379: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1381 = prims.add(t1379, t1380) # t1381: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1382 = prims.convert_element_type(t1381, dtypes.bfloat16) # t1382: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1384 = prims.mul(t1381, t1381) # t1384: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1385 = prims.sum(t1384, (2,)) # t1385: \"cuda:0 f32[1, 512]\"\n", + " # t1386 = prims.broadcast_in_dim(t1385, [1, 512, 1], [0, 1]) # t1386: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1387 = prims.div(t1386, 4096.0) # t1387: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1388 = prims.add(t1387, 1e-05) # t1388: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1389 = prims.rsqrt(t1388) # t1389: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1390 = prims.broadcast_in_dim(t1389, (1, 512, 4096), (0, 1, 2)) # t1390: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1391 = prims.mul(t1381, t1390) # t1391: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1395 = prims.convert_element_type(t1393, dtypes.float32) # t1395: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1396 = prims.mul(t1391, t1395) # t1396: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1397 = prims.convert_element_type(t1396, dtypes.bfloat16) # t1397: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1398 = torch.nn.functional.linear(t1397, t30, None) # t1398: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1398 = ltorch.linear(t1397, t30, None) # t1398: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1398 = prims.linear(t1397, t30, None) # t1398: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1399 = torch.nn.functional.linear(t1397, t46, None) # t1399: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1399 = ltorch.linear(t1397, t46, None) # t1399: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1399 = prims.linear(t1397, t46, None) # t1399: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1413] = nvFusion59(t1398, t1399)\n", + " # t1400 = prims.convert_element_type(t1398, dtypes.float32) # t1400: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1401 = prims.neg(t1400) # t1401: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1402 = prims.exp(t1401) # t1402: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1403 = 
prims.add(1.0, t1402) # t1403: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1404 = prims.reciprocal(t1403) # t1404: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1408 = prims.mul(t1400, t1404) # t1408: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1411 = prims.convert_element_type(t1399, dtypes.float32) # t1411: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1412 = prims.mul(t1408, t1411) # t1412: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1413 = prims.convert_element_type(t1412, dtypes.bfloat16) # t1413: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1414 = torch.nn.functional.linear(t1413, t108, None) # t1414: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1414 = ltorch.linear(t1413, t108, None) # t1414: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1414 = prims.linear(t1413, t108, None) # t1414: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1418, t1425, t1433] = nvFusion60(t1382, t1414, t1429)\n", + " # t1416 = prims.convert_element_type(t1382, dtypes.float32) # t1416: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1415 = prims.convert_element_type(t1414, dtypes.float32) # t1415: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1417 = prims.add(t1415, t1416) # t1417: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1418 = prims.convert_element_type(t1417, dtypes.bfloat16) # t1418: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1420 = prims.mul(t1417, t1417) # t1420: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1421 = prims.sum(t1420, (2,)) # t1421: \"cuda:0 f32[1, 512]\"\n", + " # t1422 = prims.broadcast_in_dim(t1421, [1, 512, 1], [0, 1]) # t1422: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1423 = prims.div(t1422, 4096.0) # t1423: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1424 = prims.add(t1423, 1e-05) # t1424: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1425 = prims.rsqrt(t1424) # t1425: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1426 = prims.broadcast_in_dim(t1425, (1, 512, 4096), (0, 1, 2)) # t1426: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1427 = prims.mul(t1417, t1426) # t1427: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1431 = prims.convert_element_type(t1429, dtypes.float32) # t1431: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1432 = prims.mul(t1427, t1431) # t1432: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1433 = prims.convert_element_type(t1432, dtypes.bfloat16) # t1433: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1434 = torch.nn.functional.linear(t1433, t15, None) # t1434: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1434 = ltorch.linear(t1433, t15, None) # t1434: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1434 = prims.linear(t1433, t15, None) # t1434: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1435 = torch.reshape(t1434, (1, 512, 32, 3, 128)) # t1435: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1435 = ltorch.reshape(t1434, (1, 512, 32, 3, 128)) # t1435: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1435 = prims.reshape(t1434, (1, 512, 32, 3, 128)) # t1435: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1434\n", + " t1436 = torch.permute(t1435, (0, 2, 3, 1, 4)) # t1436: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1436 = ltorch.permute(t1435, (0, 2, 3, 1, 4)) # t1436: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1436 = prims.transpose(t1435, (0, 2, 3, 1, 4)) # t1436: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1435\n", + " (t1437, t1438, t1439) = torch.split(t1436, (1, 1, 1), 2)\n", + " # (t1437, t1438, t1439) = ltorch.split(t1436, (1, 1, 1), 2)\n", + " # t1437 = prims.slice_prim(t1436, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1437: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1438 = prims.slice_prim(t1436, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1438: 
\"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1439 = prims.slice_prim(t1436, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1439: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1436\n", + " t1440 = torch.reshape(t1437, (1, 32, 512, 128)) # t1440: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1440 = ltorch.reshape(t1437, (1, 32, 512, 128)) # t1440: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1440 = prims.reshape(t1437, (1, 32, 512, 128)) # t1440: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1437\n", + " t1441 = torch.reshape(t1438, (1, 32, 512, 128)) # t1441: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1441 = ltorch.reshape(t1438, (1, 32, 512, 128)) # t1441: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1441 = prims.reshape(t1438, (1, 32, 512, 128)) # t1441: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1438\n", + " t1442 = torch.reshape(t1439, (1, 32, 512, 128)) # t1442: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1442 = ltorch.reshape(t1439, (1, 32, 512, 128)) # t1442: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1442 = prims.reshape(t1439, (1, 32, 512, 128)) # t1442: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1439\n", + " t1443 = torch_slice_prim_impl(t1440, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1443: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1458 = torch_slice_prim_impl(t1441, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1458: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1473 = torch_slice_prim_impl(t1440, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1473: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1440\n", + " t1475 = torch_slice_prim_impl(t1441, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1475: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1441\n", + " t1444 = torch_slice_prim_impl(t1443, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1444: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1445 = torch_slice_prim_impl(t1443, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1445: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1459 = torch_slice_prim_impl(t1458, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1459: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1460 = torch_slice_prim_impl(t1458, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1460: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1448, t1463] = nvFusion61(t1443, t1445, t1458, t1460)\n", + " # t1446 = prims.convert_element_type(t1445, dtypes.float32) # t1446: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1447 = prims.neg(t1446) # t1447: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1448 = prims.convert_element_type(t1447, dtypes.bfloat16) # t1448: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1461 = prims.convert_element_type(t1460, dtypes.float32) # t1461: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1462 = prims.neg(t1461) # t1462: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1463 = prims.convert_element_type(t1462, dtypes.bfloat16) # t1463: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1445, t1460\n", + " t1464 = torch.cat((t1463, t1459), -1) # t1464: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1464 = ltorch.cat((t1463, t1459), -1) # t1464: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1464 = prims.cat((t1463, t1459), -1) # t1464: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1463, t1459\n", + " t1449 = torch.cat((t1448, t1444), -1) # t1449: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1449 = ltorch.cat((t1448, t1444), -1) # t1449: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1449 = prims.cat((t1448, t1444), -1) # t1449: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1448, t1444\n", + " [t1457, t1472] = nvFusion62(t1443, 
t1449, t1458, t1464, t154, t157)\n", + " # t1451 = prims.convert_element_type(t1443, dtypes.float32) # t1451: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1466 = prims.convert_element_type(t1458, dtypes.float32) # t1466: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1467 = prims.mul(t1466, t154) # t1467: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1469 = prims.convert_element_type(t1464, dtypes.float32) # t1469: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1470 = prims.mul(t1469, t157) # t1470: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1471 = prims.add(t1467, t1470) # t1471: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1472 = prims.convert_element_type(t1471, dtypes.bfloat16) # t1472: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1452 = prims.mul(t1451, t154) # t1452: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1454 = prims.convert_element_type(t1449, dtypes.float32) # t1454: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1455 = prims.mul(t1454, t157) # t1455: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1456 = prims.add(t1452, t1455) # t1456: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1457 = prims.convert_element_type(t1456, dtypes.bfloat16) # t1457: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1443, t1449, t1458, t1464\n", + " t1476 = torch.cat((t1472, t1475), -1) # t1476: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1476 = ltorch.cat((t1472, t1475), -1) # t1476: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1476 = prims.cat((t1472, t1475), -1) # t1476: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1472, t1475\n", + " t1474 = torch.cat((t1457, t1473), -1) # t1474: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1474 = ltorch.cat((t1457, t1473), -1) # t1474: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1474 = prims.cat((t1457, t1473), -1) # t1474: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1457, t1473\n", + " (t1477, t1478, t1479, t1480, _, _, t1481, t1482, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1474, t1476, t1442, 0.0, True, scale=0.08838834764831843)\n", + " t1484 = torch.permute(t1477, (0, 2, 1, 3)) # t1484: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1484 = ltorch.permute(t1477, (0, 2, 1, 3)) # t1484: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1484 = prims.transpose(t1477, (0, 2, 1, 3)) # t1484: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1485 = torch.reshape(t1484, (1, 512, 4096)) # t1485: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1485 = ltorch.reshape(t1484, (1, 512, 4096)) # t1485: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1485 = prims.reshape(t1484, (1, 512, 4096)) # t1485: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1484\n", + " t1486 = torch.nn.functional.linear(t1485, t109, None) # t1486: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1486 = ltorch.linear(t1485, t109, None) # t1486: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1486 = prims.linear(t1485, t109, None) # t1486: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1490, t1497, t1505] = nvFusion63(t1418, t1486, t1501)\n", + " # t1488 = prims.convert_element_type(t1418, dtypes.float32) # t1488: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1487 = prims.convert_element_type(t1486, dtypes.float32) # t1487: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1489 = prims.add(t1487, t1488) # t1489: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1490 = prims.convert_element_type(t1489, dtypes.bfloat16) # t1490: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1492 = prims.mul(t1489, t1489) # t1492: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1493 = prims.sum(t1492, (2,)) # t1493: \"cuda:0 f32[1, 512]\"\n", + " # t1494 = prims.broadcast_in_dim(t1493, [1, 512, 1], [0, 1]) # t1494: 
\"cuda:0 f32[1, 512, 1]\"\n", + " # t1495 = prims.div(t1494, 4096.0) # t1495: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1496 = prims.add(t1495, 1e-05) # t1496: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1497 = prims.rsqrt(t1496) # t1497: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1498 = prims.broadcast_in_dim(t1497, (1, 512, 4096), (0, 1, 2)) # t1498: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1499 = prims.mul(t1489, t1498) # t1499: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1503 = prims.convert_element_type(t1501, dtypes.float32) # t1503: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1504 = prims.mul(t1499, t1503) # t1504: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1505 = prims.convert_element_type(t1504, dtypes.bfloat16) # t1505: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1506 = torch.nn.functional.linear(t1505, t31, None) # t1506: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1506 = ltorch.linear(t1505, t31, None) # t1506: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1506 = prims.linear(t1505, t31, None) # t1506: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1507 = torch.nn.functional.linear(t1505, t47, None) # t1507: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1507 = ltorch.linear(t1505, t47, None) # t1507: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1507 = prims.linear(t1505, t47, None) # t1507: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1521] = nvFusion64(t1506, t1507)\n", + " # t1508 = prims.convert_element_type(t1506, dtypes.float32) # t1508: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1509 = prims.neg(t1508) # t1509: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1510 = prims.exp(t1509) # t1510: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1511 = prims.add(1.0, t1510) # t1511: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1512 = prims.reciprocal(t1511) # t1512: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1516 = prims.mul(t1508, t1512) # t1516: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1519 = prims.convert_element_type(t1507, dtypes.float32) # t1519: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1520 = prims.mul(t1516, t1519) # t1520: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1521 = prims.convert_element_type(t1520, dtypes.bfloat16) # t1521: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1522 = torch.nn.functional.linear(t1521, t110, None) # t1522: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1522 = ltorch.linear(t1521, t110, None) # t1522: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1522 = prims.linear(t1521, t110, None) # t1522: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1526, t1533, t1541] = nvFusion65(t1490, t1522, t1537)\n", + " # t1524 = prims.convert_element_type(t1490, dtypes.float32) # t1524: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1523 = prims.convert_element_type(t1522, dtypes.float32) # t1523: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1525 = prims.add(t1523, t1524) # t1525: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1526 = prims.convert_element_type(t1525, dtypes.bfloat16) # t1526: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1528 = prims.mul(t1525, t1525) # t1528: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1529 = prims.sum(t1528, (2,)) # t1529: \"cuda:0 f32[1, 512]\"\n", + " # t1530 = prims.broadcast_in_dim(t1529, [1, 512, 1], [0, 1]) # t1530: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1531 = prims.div(t1530, 4096.0) # t1531: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1532 = prims.add(t1531, 1e-05) # t1532: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1533 = prims.rsqrt(t1532) # t1533: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1534 = prims.broadcast_in_dim(t1533, (1, 512, 4096), (0, 1, 2)) # t1534: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1535 = prims.mul(t1525, t1534) # t1535: \"cuda:0 f32[1, 512, 
4096]\"\n", + " # t1539 = prims.convert_element_type(t1537, dtypes.float32) # t1539: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1540 = prims.mul(t1535, t1539) # t1540: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1541 = prims.convert_element_type(t1540, dtypes.bfloat16) # t1541: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1542 = torch.nn.functional.linear(t1541, t16, None) # t1542: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1542 = ltorch.linear(t1541, t16, None) # t1542: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1542 = prims.linear(t1541, t16, None) # t1542: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1543 = torch.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1543 = ltorch.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1543 = prims.reshape(t1542, (1, 512, 32, 3, 128)) # t1543: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1542\n", + " t1544 = torch.permute(t1543, (0, 2, 3, 1, 4)) # t1544: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1544 = ltorch.permute(t1543, (0, 2, 3, 1, 4)) # t1544: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1544 = prims.transpose(t1543, (0, 2, 3, 1, 4)) # t1544: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1543\n", + " (t1545, t1546, t1547) = torch.split(t1544, (1, 1, 1), 2)\n", + " # (t1545, t1546, t1547) = ltorch.split(t1544, (1, 1, 1), 2)\n", + " # t1545 = prims.slice_prim(t1544, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1545: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1546 = prims.slice_prim(t1544, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1546: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1547 = prims.slice_prim(t1544, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1547: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1544\n", + " t1548 = torch.reshape(t1545, (1, 32, 512, 128)) # t1548: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1548 = ltorch.reshape(t1545, (1, 32, 512, 128)) # t1548: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1548 = prims.reshape(t1545, (1, 32, 512, 128)) # t1548: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1545\n", + " t1549 = torch.reshape(t1546, (1, 32, 512, 128)) # t1549: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1549 = ltorch.reshape(t1546, (1, 32, 512, 128)) # t1549: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1549 = prims.reshape(t1546, (1, 32, 512, 128)) # t1549: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1546\n", + " t1550 = torch.reshape(t1547, (1, 32, 512, 128)) # t1550: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1550 = ltorch.reshape(t1547, (1, 32, 512, 128)) # t1550: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1550 = prims.reshape(t1547, (1, 32, 512, 128)) # t1550: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1547\n", + " t1551 = torch_slice_prim_impl(t1548, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1551: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1566 = torch_slice_prim_impl(t1549, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1566: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1581 = torch_slice_prim_impl(t1548, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1581: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1548\n", + " t1583 = torch_slice_prim_impl(t1549, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1583: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1549\n", + " t1552 = torch_slice_prim_impl(t1551, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1552: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1553 = torch_slice_prim_impl(t1551, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 
1]) # t1553: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1567 = torch_slice_prim_impl(t1566, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1567: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1568 = torch_slice_prim_impl(t1566, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1568: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1556, t1571] = nvFusion66(t1551, t1553, t1566, t1568)\n", + " # t1554 = prims.convert_element_type(t1553, dtypes.float32) # t1554: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1555 = prims.neg(t1554) # t1555: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1556 = prims.convert_element_type(t1555, dtypes.bfloat16) # t1556: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1569 = prims.convert_element_type(t1568, dtypes.float32) # t1569: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1570 = prims.neg(t1569) # t1570: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1571 = prims.convert_element_type(t1570, dtypes.bfloat16) # t1571: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1553, t1568\n", + " t1572 = torch.cat((t1571, t1567), -1) # t1572: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1572 = ltorch.cat((t1571, t1567), -1) # t1572: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1572 = prims.cat((t1571, t1567), -1) # t1572: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1571, t1567\n", + " t1557 = torch.cat((t1556, t1552), -1) # t1557: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1557 = ltorch.cat((t1556, t1552), -1) # t1557: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1557 = prims.cat((t1556, t1552), -1) # t1557: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1556, t1552\n", + " [t1565, t1580] = nvFusion67(t154, t1551, t1557, t1566, t157, t1572)\n", + " # t1559 = prims.convert_element_type(t1551, dtypes.float32) # t1559: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1574 = prims.convert_element_type(t1566, dtypes.float32) # t1574: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1575 = prims.mul(t1574, t154) # t1575: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1577 = prims.convert_element_type(t1572, dtypes.float32) # t1577: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1578 = prims.mul(t1577, t157) # t1578: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1579 = prims.add(t1575, t1578) # t1579: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1580 = prims.convert_element_type(t1579, dtypes.bfloat16) # t1580: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1560 = prims.mul(t1559, t154) # t1560: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1562 = prims.convert_element_type(t1557, dtypes.float32) # t1562: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1563 = prims.mul(t1562, t157) # t1563: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1564 = prims.add(t1560, t1563) # t1564: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1565 = prims.convert_element_type(t1564, dtypes.bfloat16) # t1565: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1551, t1557, t1566, t1572\n", + " t1584 = torch.cat((t1580, t1583), -1) # t1584: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1584 = ltorch.cat((t1580, t1583), -1) # t1584: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1584 = prims.cat((t1580, t1583), -1) # t1584: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1580, t1583\n", + " t1582 = torch.cat((t1565, t1581), -1) # t1582: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1582 = ltorch.cat((t1565, t1581), -1) # t1582: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1582 = prims.cat((t1565, t1581), -1) # t1582: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1565, t1581\n", + " (t1585, t1586, t1587, t1588, _, _, t1589, t1590, _) = 
sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1582, t1584, t1550, 0.0, True, scale=0.08838834764831843)\n", + " t1592 = torch.permute(t1585, (0, 2, 1, 3)) # t1592: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1592 = ltorch.permute(t1585, (0, 2, 1, 3)) # t1592: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1592 = prims.transpose(t1585, (0, 2, 1, 3)) # t1592: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1593 = torch.reshape(t1592, (1, 512, 4096)) # t1593: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1593 = ltorch.reshape(t1592, (1, 512, 4096)) # t1593: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1593 = prims.reshape(t1592, (1, 512, 4096)) # t1593: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1592\n", + " t1594 = torch.nn.functional.linear(t1593, t111, None) # t1594: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1594 = ltorch.linear(t1593, t111, None) # t1594: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1594 = prims.linear(t1593, t111, None) # t1594: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1598, t1605, t1613] = nvFusion68(t1526, t1594, t1609)\n", + " # t1596 = prims.convert_element_type(t1526, dtypes.float32) # t1596: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1595 = prims.convert_element_type(t1594, dtypes.float32) # t1595: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1597 = prims.add(t1595, t1596) # t1597: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1598 = prims.convert_element_type(t1597, dtypes.bfloat16) # t1598: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1600 = prims.mul(t1597, t1597) # t1600: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1601 = prims.sum(t1600, (2,)) # t1601: \"cuda:0 f32[1, 512]\"\n", + " # t1602 = prims.broadcast_in_dim(t1601, [1, 512, 1], [0, 1]) # t1602: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1603 = prims.div(t1602, 4096.0) # t1603: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1604 = prims.add(t1603, 1e-05) # t1604: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1605 = prims.rsqrt(t1604) # t1605: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1606 = prims.broadcast_in_dim(t1605, (1, 512, 4096), (0, 1, 2)) # t1606: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1607 = prims.mul(t1597, t1606) # t1607: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1611 = prims.convert_element_type(t1609, dtypes.float32) # t1611: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1612 = prims.mul(t1607, t1611) # t1612: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1613 = prims.convert_element_type(t1612, dtypes.bfloat16) # t1613: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1614 = torch.nn.functional.linear(t1613, t32, None) # t1614: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1614 = ltorch.linear(t1613, t32, None) # t1614: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1614 = prims.linear(t1613, t32, None) # t1614: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1615 = torch.nn.functional.linear(t1613, t48, None) # t1615: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1615 = ltorch.linear(t1613, t48, None) # t1615: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1615 = prims.linear(t1613, t48, None) # t1615: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1629] = nvFusion69(t1614, t1615)\n", + " # t1616 = prims.convert_element_type(t1614, dtypes.float32) # t1616: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1617 = prims.neg(t1616) # t1617: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1618 = prims.exp(t1617) # t1618: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1619 = prims.add(1.0, t1618) # t1619: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1620 = prims.reciprocal(t1619) # t1620: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1624 = prims.mul(t1616, t1620) # t1624: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1627 = 
prims.convert_element_type(t1615, dtypes.float32) # t1627: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1628 = prims.mul(t1624, t1627) # t1628: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1629 = prims.convert_element_type(t1628, dtypes.bfloat16) # t1629: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1630 = torch.nn.functional.linear(t1629, t112, None) # t1630: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1630 = ltorch.linear(t1629, t112, None) # t1630: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1630 = prims.linear(t1629, t112, None) # t1630: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1634, t1641, t1649] = nvFusion70(t1598, t1630, t1645)\n", + " # t1632 = prims.convert_element_type(t1598, dtypes.float32) # t1632: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1631 = prims.convert_element_type(t1630, dtypes.float32) # t1631: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1633 = prims.add(t1631, t1632) # t1633: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1634 = prims.convert_element_type(t1633, dtypes.bfloat16) # t1634: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1636 = prims.mul(t1633, t1633) # t1636: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1637 = prims.sum(t1636, (2,)) # t1637: \"cuda:0 f32[1, 512]\"\n", + " # t1638 = prims.broadcast_in_dim(t1637, [1, 512, 1], [0, 1]) # t1638: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1639 = prims.div(t1638, 4096.0) # t1639: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1640 = prims.add(t1639, 1e-05) # t1640: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1641 = prims.rsqrt(t1640) # t1641: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1642 = prims.broadcast_in_dim(t1641, (1, 512, 4096), (0, 1, 2)) # t1642: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1643 = prims.mul(t1633, t1642) # t1643: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1647 = prims.convert_element_type(t1645, dtypes.float32) # t1647: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1648 = prims.mul(t1643, t1647) # t1648: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1649 = prims.convert_element_type(t1648, dtypes.bfloat16) # t1649: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1650 = torch.nn.functional.linear(t1649, t17, None) # t1650: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1650 = ltorch.linear(t1649, t17, None) # t1650: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1650 = prims.linear(t1649, t17, None) # t1650: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1651 = torch.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1651 = ltorch.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1651 = prims.reshape(t1650, (1, 512, 32, 3, 128)) # t1651: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1650\n", + " t1652 = torch.permute(t1651, (0, 2, 3, 1, 4)) # t1652: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1652 = ltorch.permute(t1651, (0, 2, 3, 1, 4)) # t1652: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1652 = prims.transpose(t1651, (0, 2, 3, 1, 4)) # t1652: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1651\n", + " (t1653, t1654, t1655) = torch.split(t1652, (1, 1, 1), 2)\n", + " # (t1653, t1654, t1655) = ltorch.split(t1652, (1, 1, 1), 2)\n", + " # t1653 = prims.slice_prim(t1652, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1653: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1654 = prims.slice_prim(t1652, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1654: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1655 = prims.slice_prim(t1652, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1655: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1652\n", + " t1656 = torch.reshape(t1653, (1, 
32, 512, 128)) # t1656: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1656 = ltorch.reshape(t1653, (1, 32, 512, 128)) # t1656: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1656 = prims.reshape(t1653, (1, 32, 512, 128)) # t1656: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1653\n", + " t1657 = torch.reshape(t1654, (1, 32, 512, 128)) # t1657: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1657 = ltorch.reshape(t1654, (1, 32, 512, 128)) # t1657: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1657 = prims.reshape(t1654, (1, 32, 512, 128)) # t1657: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1654\n", + " t1658 = torch.reshape(t1655, (1, 32, 512, 128)) # t1658: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1658 = ltorch.reshape(t1655, (1, 32, 512, 128)) # t1658: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1658 = prims.reshape(t1655, (1, 32, 512, 128)) # t1658: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1655\n", + " t1689 = torch_slice_prim_impl(t1656, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1689: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t1691 = torch_slice_prim_impl(t1657, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1691: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " t1659 = torch_slice_prim_impl(t1656, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1659: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1656\n", + " t1674 = torch_slice_prim_impl(t1657, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1674: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1657\n", + " t1660 = torch_slice_prim_impl(t1659, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1660: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1661 = torch_slice_prim_impl(t1659, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1661: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1675 = torch_slice_prim_impl(t1674, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1675: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1676 = torch_slice_prim_impl(t1674, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1676: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1664, t1679] = nvFusion71(t1659, t1661, t1674, t1676)\n", + " # t1662 = prims.convert_element_type(t1661, dtypes.float32) # t1662: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1663 = prims.neg(t1662) # t1663: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1664 = prims.convert_element_type(t1663, dtypes.bfloat16) # t1664: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1677 = prims.convert_element_type(t1676, dtypes.float32) # t1677: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1678 = prims.neg(t1677) # t1678: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1679 = prims.convert_element_type(t1678, dtypes.bfloat16) # t1679: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1661, t1676\n", + " t1680 = torch.cat((t1679, t1675), -1) # t1680: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1680 = ltorch.cat((t1679, t1675), -1) # t1680: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1680 = prims.cat((t1679, t1675), -1) # t1680: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1679, t1675\n", + " t1665 = torch.cat((t1664, t1660), -1) # t1665: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1665 = ltorch.cat((t1664, t1660), -1) # t1665: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1665 = prims.cat((t1664, t1660), -1) # t1665: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1664, t1660\n", + " [t1673, t1688] = nvFusion72(t154, t157, t1659, t1665, t1674, t1680)\n", + " # t1667 = prims.convert_element_type(t1659, dtypes.float32) # t1667: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1682 = prims.convert_element_type(t1674, dtypes.float32) # t1682: \"cuda:0 f32[1, 32, 
512, 128]\"\n", + " # t1683 = prims.mul(t1682, t154) # t1683: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1685 = prims.convert_element_type(t1680, dtypes.float32) # t1685: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1686 = prims.mul(t1685, t157) # t1686: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1687 = prims.add(t1683, t1686) # t1687: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1688 = prims.convert_element_type(t1687, dtypes.bfloat16) # t1688: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1668 = prims.mul(t1667, t154) # t1668: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1670 = prims.convert_element_type(t1665, dtypes.float32) # t1670: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1671 = prims.mul(t1670, t157) # t1671: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1672 = prims.add(t1668, t1671) # t1672: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1673 = prims.convert_element_type(t1672, dtypes.bfloat16) # t1673: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1659, t1665, t1674, t1680\n", + " t1692 = torch.cat((t1688, t1691), -1) # t1692: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1692 = ltorch.cat((t1688, t1691), -1) # t1692: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1692 = prims.cat((t1688, t1691), -1) # t1692: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1688, t1691\n", + " t1690 = torch.cat((t1673, t1689), -1) # t1690: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1690 = ltorch.cat((t1673, t1689), -1) # t1690: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1690 = prims.cat((t1673, t1689), -1) # t1690: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1673, t1689\n", + " (t1693, t1694, t1695, t1696, _, _, t1697, t1698, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1690, t1692, t1658, 0.0, True, scale=0.08838834764831843)\n", + " t1700 = torch.permute(t1693, (0, 2, 1, 3)) # t1700: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1700 = ltorch.permute(t1693, (0, 2, 1, 3)) # t1700: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1700 = prims.transpose(t1693, (0, 2, 1, 3)) # t1700: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1701 = torch.reshape(t1700, (1, 512, 4096)) # t1701: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1701 = ltorch.reshape(t1700, (1, 512, 4096)) # t1701: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1701 = prims.reshape(t1700, (1, 512, 4096)) # t1701: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1700\n", + " t1702 = torch.nn.functional.linear(t1701, t113, None) # t1702: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1702 = ltorch.linear(t1701, t113, None) # t1702: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1702 = prims.linear(t1701, t113, None) # t1702: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1706, t1713, t1721] = nvFusion73(t1634, t1702, t1717)\n", + " # t1704 = prims.convert_element_type(t1634, dtypes.float32) # t1704: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1703 = prims.convert_element_type(t1702, dtypes.float32) # t1703: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1705 = prims.add(t1703, t1704) # t1705: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1706 = prims.convert_element_type(t1705, dtypes.bfloat16) # t1706: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1708 = prims.mul(t1705, t1705) # t1708: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1709 = prims.sum(t1708, (2,)) # t1709: \"cuda:0 f32[1, 512]\"\n", + " # t1710 = prims.broadcast_in_dim(t1709, [1, 512, 1], [0, 1]) # t1710: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1711 = prims.div(t1710, 4096.0) # t1711: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1712 = prims.add(t1711, 1e-05) # t1712: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1713 = prims.rsqrt(t1712) # t1713: \"cuda:0 
f32[1, 512, 1]\"\n", + " # t1714 = prims.broadcast_in_dim(t1713, (1, 512, 4096), (0, 1, 2)) # t1714: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1715 = prims.mul(t1705, t1714) # t1715: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1719 = prims.convert_element_type(t1717, dtypes.float32) # t1719: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1720 = prims.mul(t1715, t1719) # t1720: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1721 = prims.convert_element_type(t1720, dtypes.bfloat16) # t1721: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1722 = torch.nn.functional.linear(t1721, t33, None) # t1722: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1722 = ltorch.linear(t1721, t33, None) # t1722: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1722 = prims.linear(t1721, t33, None) # t1722: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1723 = torch.nn.functional.linear(t1721, t49, None) # t1723: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1723 = ltorch.linear(t1721, t49, None) # t1723: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1723 = prims.linear(t1721, t49, None) # t1723: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1737] = nvFusion74(t1722, t1723)\n", + " # t1724 = prims.convert_element_type(t1722, dtypes.float32) # t1724: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1725 = prims.neg(t1724) # t1725: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1726 = prims.exp(t1725) # t1726: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1727 = prims.add(1.0, t1726) # t1727: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1728 = prims.reciprocal(t1727) # t1728: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1732 = prims.mul(t1724, t1728) # t1732: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1735 = prims.convert_element_type(t1723, dtypes.float32) # t1735: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1736 = prims.mul(t1732, t1735) # t1736: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1737 = prims.convert_element_type(t1736, dtypes.bfloat16) # t1737: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1738 = torch.nn.functional.linear(t1737, t114, None) # t1738: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1738 = ltorch.linear(t1737, t114, None) # t1738: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1738 = prims.linear(t1737, t114, None) # t1738: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1742, t1749, t1757] = nvFusion75(t1706, t1738, t1753)\n", + " # t1740 = prims.convert_element_type(t1706, dtypes.float32) # t1740: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1739 = prims.convert_element_type(t1738, dtypes.float32) # t1739: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1741 = prims.add(t1739, t1740) # t1741: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1742 = prims.convert_element_type(t1741, dtypes.bfloat16) # t1742: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1744 = prims.mul(t1741, t1741) # t1744: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1745 = prims.sum(t1744, (2,)) # t1745: \"cuda:0 f32[1, 512]\"\n", + " # t1746 = prims.broadcast_in_dim(t1745, [1, 512, 1], [0, 1]) # t1746: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1747 = prims.div(t1746, 4096.0) # t1747: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1748 = prims.add(t1747, 1e-05) # t1748: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1749 = prims.rsqrt(t1748) # t1749: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1750 = prims.broadcast_in_dim(t1749, (1, 512, 4096), (0, 1, 2)) # t1750: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1751 = prims.mul(t1741, t1750) # t1751: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1755 = prims.convert_element_type(t1753, dtypes.float32) # t1755: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1756 = prims.mul(t1751, t1755) # t1756: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1757 = 
prims.convert_element_type(t1756, dtypes.bfloat16) # t1757: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1758 = torch.nn.functional.linear(t1757, t18, None) # t1758: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1758 = ltorch.linear(t1757, t18, None) # t1758: \"cuda:0 bf16[1, 512, 12288]\"\n", + " # t1758 = prims.linear(t1757, t18, None) # t1758: \"cuda:0 bf16[1, 512, 12288]\"\n", + " t1759 = torch.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1759 = ltorch.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " # t1759 = prims.reshape(t1758, (1, 512, 32, 3, 128)) # t1759: \"cuda:0 bf16[1, 512, 32, 3, 128]\"\n", + " del t1758\n", + " t1760 = torch.permute(t1759, (0, 2, 3, 1, 4)) # t1760: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1760 = ltorch.permute(t1759, (0, 2, 3, 1, 4)) # t1760: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " # t1760 = prims.transpose(t1759, (0, 2, 3, 1, 4)) # t1760: \"cuda:0 bf16[1, 32, 3, 512, 128]\"\n", + " del t1759\n", + " (t1761, t1762, t1763) = torch.split(t1760, (1, 1, 1), 2)\n", + " # (t1761, t1762, t1763) = ltorch.split(t1760, (1, 1, 1), 2)\n", + " # t1761 = prims.slice_prim(t1760, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1761: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1762 = prims.slice_prim(t1760, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1762: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " # t1763 = prims.slice_prim(t1760, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1763: \"cuda:0 bf16[1, 32, 1, 512, 128]\"\n", + " del t1760\n", + " t1764 = torch.reshape(t1761, (1, 32, 512, 128)) # t1764: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1764 = ltorch.reshape(t1761, (1, 32, 512, 128)) # t1764: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1764 = prims.reshape(t1761, (1, 32, 512, 128)) # t1764: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1761\n", + " t1765 = torch.reshape(t1762, (1, 32, 512, 128)) # t1765: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1765 = ltorch.reshape(t1762, (1, 32, 512, 128)) # t1765: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1765 = prims.reshape(t1762, (1, 32, 512, 128)) # t1765: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1762\n", + " t1766 = torch.reshape(t1763, (1, 32, 512, 128)) # t1766: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1766 = ltorch.reshape(t1763, (1, 32, 512, 128)) # t1766: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1766 = prims.reshape(t1763, (1, 32, 512, 128)) # t1766: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1763\n", + " t1767 = torch_slice_prim_impl(t1764, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1767: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1782 = torch_slice_prim_impl(t1765, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1782: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " t1797 = torch_slice_prim_impl(t1764, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1797: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1764\n", + " t1799 = torch_slice_prim_impl(t1765, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1799: \"cuda:0 bf16[1, 32, 512, 0]\"\n", + " del t1765\n", + " t1768 = torch_slice_prim_impl(t1767, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1768: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1769 = torch_slice_prim_impl(t1767, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1769: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1783 = torch_slice_prim_impl(t1782, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1783: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " t1784 = 
torch_slice_prim_impl(t1782, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1784: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " [t1772, t1787] = nvFusion76(t1767, t1769, t1782, t1784)\n", + " # t1770 = prims.convert_element_type(t1769, dtypes.float32) # t1770: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1771 = prims.neg(t1770) # t1771: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1772 = prims.convert_element_type(t1771, dtypes.bfloat16) # t1772: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " # t1785 = prims.convert_element_type(t1784, dtypes.float32) # t1785: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1786 = prims.neg(t1785) # t1786: \"cuda:0 f32[1, 32, 512, 64]\"\n", + " # t1787 = prims.convert_element_type(t1786, dtypes.bfloat16) # t1787: \"cuda:0 bf16[1, 32, 512, 64]\"\n", + " del t1769, t1784\n", + " t1788 = torch.cat((t1787, t1783), -1) # t1788: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1788 = ltorch.cat((t1787, t1783), -1) # t1788: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1788 = prims.cat((t1787, t1783), -1) # t1788: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1787, t1783\n", + " t1773 = torch.cat((t1772, t1768), -1) # t1773: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1773 = ltorch.cat((t1772, t1768), -1) # t1773: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1773 = prims.cat((t1772, t1768), -1) # t1773: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1772, t1768\n", + " [t1781, t1796] = nvFusion77(t154, t157, t1767, t1773, t1782, t1788)\n", + " # t1775 = prims.convert_element_type(t1767, dtypes.float32) # t1775: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1790 = prims.convert_element_type(t1782, dtypes.float32) # t1790: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1791 = prims.mul(t1790, t154) # t1791: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1793 = prims.convert_element_type(t1788, dtypes.float32) # t1793: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1794 = prims.mul(t1793, t157) # t1794: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1795 = prims.add(t1791, t1794) # t1795: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1796 = prims.convert_element_type(t1795, dtypes.bfloat16) # t1796: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1776 = prims.mul(t1775, t154) # t1776: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1778 = prims.convert_element_type(t1773, dtypes.float32) # t1778: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1779 = prims.mul(t1778, t157) # t1779: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1780 = prims.add(t1776, t1779) # t1780: \"cuda:0 f32[1, 32, 512, 128]\"\n", + " # t1781 = prims.convert_element_type(t1780, dtypes.bfloat16) # t1781: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1767, t1773, t1782, t1788\n", + " t1800 = torch.cat((t1796, t1799), -1) # t1800: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1800 = ltorch.cat((t1796, t1799), -1) # t1800: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1800 = prims.cat((t1796, t1799), -1) # t1800: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1796, t1799\n", + " t1798 = torch.cat((t1781, t1797), -1) # t1798: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1798 = ltorch.cat((t1781, t1797), -1) # t1798: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " # t1798 = prims.cat((t1781, t1797), -1) # t1798: \"cuda:0 bf16[1, 32, 512, 128]\"\n", + " del t1781, t1797\n", + " (t1801, t1802, t1803, t1804, _, _, t1805, t1806, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1798, t1800, t1766, 0.0, True, scale=0.08838834764831843)\n", + " t1808 = torch.permute(t1801, (0, 2, 1, 3)) # t1808: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1808 = ltorch.permute(t1801, (0, 
2, 1, 3)) # t1808: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " # t1808 = prims.transpose(t1801, (0, 2, 1, 3)) # t1808: \"cuda:0 bf16[1, 512, 32, 128]\"\n", + " t1809 = torch.reshape(t1808, (1, 512, 4096)) # t1809: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1809 = ltorch.reshape(t1808, (1, 512, 4096)) # t1809: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1809 = prims.reshape(t1808, (1, 512, 4096)) # t1809: \"cuda:0 bf16[1, 512, 4096]\"\n", + " del t1808\n", + " t1810 = torch.nn.functional.linear(t1809, t115, None) # t1810: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1810 = ltorch.linear(t1809, t115, None) # t1810: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1810 = prims.linear(t1809, t115, None) # t1810: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1814, t1821, t1829] = nvFusion78(t1742, t1810, t1825)\n", + " # t1812 = prims.convert_element_type(t1742, dtypes.float32) # t1812: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1811 = prims.convert_element_type(t1810, dtypes.float32) # t1811: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1813 = prims.add(t1811, t1812) # t1813: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1814 = prims.convert_element_type(t1813, dtypes.bfloat16) # t1814: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1816 = prims.mul(t1813, t1813) # t1816: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1817 = prims.sum(t1816, (2,)) # t1817: \"cuda:0 f32[1, 512]\"\n", + " # t1818 = prims.broadcast_in_dim(t1817, [1, 512, 1], [0, 1]) # t1818: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1819 = prims.div(t1818, 4096.0) # t1819: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1820 = prims.add(t1819, 1e-05) # t1820: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1821 = prims.rsqrt(t1820) # t1821: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1822 = prims.broadcast_in_dim(t1821, (1, 512, 4096), (0, 1, 2)) # t1822: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1823 = prims.mul(t1813, t1822) # t1823: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1827 = prims.convert_element_type(t1825, dtypes.float32) # t1827: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1828 = prims.mul(t1823, t1827) # t1828: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1829 = prims.convert_element_type(t1828, dtypes.bfloat16) # t1829: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1831 = torch.nn.functional.linear(t1829, t50, None) # t1831: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1831 = ltorch.linear(t1829, t50, None) # t1831: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1831 = prims.linear(t1829, t50, None) # t1831: \"cuda:0 bf16[1, 512, 11008]\"\n", + " t1830 = torch.nn.functional.linear(t1829, t34, None) # t1830: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1830 = ltorch.linear(t1829, t34, None) # t1830: \"cuda:0 bf16[1, 512, 11008]\"\n", + " # t1830 = prims.linear(t1829, t34, None) # t1830: \"cuda:0 bf16[1, 512, 11008]\"\n", + " [t1845] = nvFusion79(t1830, t1831)\n", + " # t1832 = prims.convert_element_type(t1830, dtypes.float32) # t1832: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1833 = prims.neg(t1832) # t1833: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1834 = prims.exp(t1833) # t1834: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1835 = prims.add(1.0, t1834) # t1835: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1836 = prims.reciprocal(t1835) # t1836: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1840 = prims.mul(t1832, t1836) # t1840: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1843 = prims.convert_element_type(t1831, dtypes.float32) # t1843: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1844 = prims.mul(t1840, t1843) # t1844: \"cuda:0 f32[1, 512, 11008]\"\n", + " # t1845 = prims.convert_element_type(t1844, dtypes.bfloat16) # t1845: \"cuda:0 bf16[1, 512, 
11008]\"\n", + " t1846 = torch.nn.functional.linear(t1845, t116, None) # t1846: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1846 = ltorch.linear(t1845, t116, None) # t1846: \"cuda:0 bf16[1, 512, 4096]\"\n", + " # t1846 = prims.linear(t1845, t116, None) # t1846: \"cuda:0 bf16[1, 512, 4096]\"\n", + " [t1857, t1865] = nvFusion80(t1814, t1846, t1861)\n", + " # t1848 = prims.convert_element_type(t1814, dtypes.float32) # t1848: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1847 = prims.convert_element_type(t1846, dtypes.float32) # t1847: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1849 = prims.add(t1847, t1848) # t1849: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1852 = prims.mul(t1849, t1849) # t1852: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1853 = prims.sum(t1852, (2,)) # t1853: \"cuda:0 f32[1, 512]\"\n", + " # t1854 = prims.broadcast_in_dim(t1853, [1, 512, 1], [0, 1]) # t1854: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1855 = prims.div(t1854, 4096.0) # t1855: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1856 = prims.add(t1855, 1e-05) # t1856: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1857 = prims.rsqrt(t1856) # t1857: \"cuda:0 f32[1, 512, 1]\"\n", + " # t1858 = prims.broadcast_in_dim(t1857, (1, 512, 4096), (0, 1, 2)) # t1858: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1859 = prims.mul(t1849, t1858) # t1859: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1863 = prims.convert_element_type(t1861, dtypes.float32) # t1863: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1864 = prims.mul(t1859, t1863) # t1864: \"cuda:0 f32[1, 512, 4096]\"\n", + " # t1865 = prims.convert_element_type(t1864, dtypes.bfloat16) # t1865: \"cuda:0 bf16[1, 512, 4096]\"\n", + " t1866 = torch.nn.functional.linear(t1865, t51, None) # t1866: \"cuda:0 bf16[1, 512, 32000]\"\n", + " # t1866 = ltorch.linear(t1865, t51, None) # t1866: \"cuda:0 bf16[1, 512, 32000]\"\n", + " # t1866 = prims.linear(t1865, t51, None) # t1866: \"cuda:0 bf16[1, 512, 32000]\"\n", + " return {'output': t1866, 'flat_args': [t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19, t20, t21, t22, t23, t24, t25, t26, t27, t28, t29, t30, t31, t32, t33, t34, t35, t36, t37, t38, t39, t40, t41, t42, t43, t44, t45, t46, t47, t48, t49, t50, t51, t52, t53, t54, t55, t56, t57, t58, t59, t60, t61, t62, t63, t64, t65, t66, t67, t68, t69, t70, t71, t72, t73, t74, t75, t76, t77, t78, t79, t80, t81, t82, t83, t84, t85, t86, t87, t88, t89, t90, t91, t92, t93, t94, t95, t96, t97, t98, t99, t100, t101, t102, t103, t104, t105, t106, t107, t108, t109, t110, t111, t112, t113, t114, t115, t116, t117], 'flat_output': (t1866,)}, ((t0, t10, t100, t1001, t101, t1010, t102, t103, t104, t1042, t1044, t1045, t1046, t1047, t1048, t1049, t105, t1050, t1053, t1054, t1058, t106, t1065, t1069, t107, t1073, t1074, t1075, t108, t1089, t109, t1090, t1094, t11, t110, t1101, t1105, t1109, t111, t1118, t112, t113, t114, t115, t1150, t1152, t1153, t1154, t1155, t1156, t1157, t1158, t116, t1161, t1162, t1166, t1173, t1177, t1181, t1182, t1183, t1197, t1198, t12, t1202, t1209, t1213, t1217, t122, t1226, t1258, t1260, t1261, t1262, t1263, t1264, t1265, t1266, t1269, t1270, t1274, t1281, t1285, t1289, t129, t1290, t1291, t13, t1305, t1306, t1310, t1317, t1321, t1325, t133, t1334, t1366, t1368, t1369, t137, t1370, t1371, t1372, t1373, t1374, t1377, t1378, t1382, t1389, t1393, t1397, t1398, t1399, t14, t1413, t1414, t1418, t1425, t1429, t1433, t1442, t146, t1474, t1476, t1477, t1478, t1479, t1480, t1481, t1482, t1485, t1486, t1490, t1497, t15, t1501, t1505, t1506, t1507, t1521, t1522, t1526, t1533, t1537, t154, t1541, t1550, 
t157, t1582, t1584, t1585, t1586, t1587, t1588, t1589, t1590, t1593, t1594, t1598, t16, t1605, t1609, t1613, t1614, t1615, t1629, t1630, t1634, t1641, t1645, t1649, t1658, t1690, t1692, t1693, t1694, t1695, t1696, t1697, t1698, t17, t1701, t1702, t1706, t1713, t1717, t1721, t1722, t1723, t1737, t1738, t1742, t1749, t1753, t1757, t1766, t178, t1798, t18, t180, t1800, t1801, t1802, t1803, t1804, t1805, t1806, t1809, t181, t1810, t1814, t182, t1821, t1825, t1829, t183, t1830, t1831, t184, t1845, t1846, t185, t1857, t186, t1861, t1865, t189, t19, t190, t194, t20, t201, t205, t209, t21, t210, t211, t22, t225, t226, t23, t230, t237, t24, t241, t245, t25, t254, t26, t27, t28, t286, t288, t289, t29, t290, t291, t292, t293, t294, t297, t298, t3, t30, t302, t309, t31, t313, t317, t318, t319, t32, t33, t333, t334, t338, t34, t345, t349, t35, t353, t36, t362, t37, t38, t39, t394, t396, t397, t398, t399, t4, t40, t400, t401, t402, t405, t406, t41, t410, t417, t42, t421, t425, t426, t427, t43, t44, t441, t442, t446, t45, t453, t457, t46, t461, t47, t470, t48, t49, t5, t50, t502, t504, t505, t506, t507, t508, t509, t51, t510, t513, t514, t518, t525, t529, t533, t534, t535, t549, t550, t554, t561, t565, t569, t578, t6, t610, t612, t613, t614, t615, t616, t617, t618, t621, t622, t626, t633, t637, t641, t642, t643, t657, t658, t662, t669, t673, t677, t686, t7, t718, t720, t721, t722, t723, t724, t725, t726, t729, t730, t734, t741, t745, t749, t750, t751, t765, t766, t770, t777, t781, t785, t794, t8, t826, t828, t829, t830, t831, t832, t833, t834, t837, t838, t842, t849, t85, t853, t857, t858, t859, t86, t87, t873, t874, t878, t88, t885, t889, t89, t893, t9, t90, t902, t91, t92, t93, t934, t936, t937, t938, t939, t94, t940, t941, t942, t945, t946, t95, t950, t957, t96, t961, t965, t966, t967, t97, t98, t981, t982, t986, t99, t993, t997), (False, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 0.0, 4096.0, 4096.0, 0.08838834764831843, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 32000, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))" ] }, "execution_count": 9, @@ -1016,13 +3161,18 @@ } ], "source": [ + "print(actual.grad_fn)\n", "thunder.last_traces(thunder_model)[-1]" ] }, { "cell_type": "markdown", - "id": "4944f352", - "metadata": {}, + "id": "558f2553-37f7-4b58-b7cd-a744155613a8", + "metadata": { + "slideshow": { + "slide_type": "notes" + } + }, "source": [ "Well, that is quite a bit to look through.\n", "But here is a key thing: The function now returns a bunch of things. This is because Thunder applies the same treatment to the backward and to this end saves information from the forward. You can see a hint of this because the output has a `ThunderFunctionBackward` on as its `grad_fn`. 
(You can see the backward trace with \n", @@ -1032,19 +3182,19 @@ { "cell_type": "code", "execution_count": 10, - "id": "4d90df65", + "id": "59643398-d6e2-4c32-81bd-145a1198b1f3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[-0.9922, 0.5946, -0.2173, ..., -0.0981, -0.5058, 0.2747],\n", - " [-1.1552, 0.5770, -0.7432, ..., 0.0688, 0.1238, 0.6786],\n", - " [-0.7813, 0.6960, 0.1235, ..., -0.4840, 0.1373, 0.6490],\n", + "tensor([[[ 0.4160, -0.4668, 1.1016, ..., 0.5430, 1.2656, 0.2891],\n", + " [ 0.3320, -0.0557, 1.7891, ..., 1.0703, 1.0078, 1.2266],\n", + " [ 0.6836, -0.2871, 0.9531, ..., 0.0806, 0.7070, 0.8477],\n", " ...,\n", - " [ 0.3711, 0.1656, 0.3350, ..., -0.0294, 0.3670, 0.5099],\n", - " [-0.2544, -0.8470, 0.2063, ..., -0.1341, 0.1877, 0.2612],\n", - " [ 0.3420, -1.1421, 0.9222, ..., 0.5636, 0.1666, 0.6947]]],\n", + " [ 0.7695, -0.1260, 0.7266, ..., 0.1118, -0.0238, -1.2656],\n", + " [-0.7773, -0.5547, -0.3047, ..., -0.1807, 0.1895, 0.6875],\n", + " [ 0.8867, 0.4766, 0.3984, ..., 0.0815, -0.0879, 0.3477]]],\n", " device='cuda:0', grad_fn=<ThunderFunctionBackward>)" ] }, @@ -1059,10 +3209,10 @@ }, { "cell_type": "markdown", - "id": "7dcec40f", + "id": "17341d86-d4c9-46bd-ac5e-3a05da1ff72c", "metadata": {}, "source": [ - "One thing to keep in mind here is that for bf16, the numerical accuracy impact of rearranging operations can be quite pronounced." + "Let us clean up a bit." ] }, { @@ -1070,25 +3220,21 @@ "execution_count": 11, "id": "6ba7f715", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "maximum deviation grads: 0.00042724609375\n" - ] - } - ], + "outputs": [], "source": [ - "actual_grads = torch.autograd.grad(actual.sum(), m.parameters())\n", - "expected_grads = torch.autograd.grad(expected.sum(), m.parameters())\n", - "print(\"maximum deviation grads:\", max((a-e).abs().max().item() for a, e in zip(actual_grads, expected_grads)))" + "del actual, expected\n", + "import gc\n", + "gc.collect();" ] }, { "cell_type": "markdown", "id": "0261eb11", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ "But is it faster? Yes!" ] @@ -1096,50 +3242,52 @@ { "cell_type": "code", "execution_count": 12, - "id": "854f29a5", + "id": "bccec79b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "154 ms ± 281 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n", - "150 ms ± 342 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "240 ms ± 105 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", + "208 ms ± 147 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], "source": [ - "import gc\n", - "gc.collect()\n", "%timeit r = m(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()\n", "%timeit r = thunder_model(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()" ] }, + { + "cell_type": "markdown", + "id": "1d31e7f8", + "metadata": {}, + "source": [ + "So far, so good! Thunder should work with LitGPT today and we are busy adding the support required to run other models as well!\n" + ] + }, { "cell_type": "code", "execution_count": 13, - "id": "eb177aad", + "id": "ecad9125-bbf2-42c8-b11c-23eed4a6cd8f", "metadata": {}, "outputs": [], "source": [ "del m, thunder_model\n", "import gc\n", "gc.collect()\n", - "torch.cuda.empty_cache()" - ] - }, - { - "cell_type": "markdown", - "id": "1d31e7f8", - "metadata": {}, - "source": [ - "So far, so good! 
Thunder should work with LitGPT today and we busy are adding the support required to run other models as well!" + "torch.cuda.empty_cache()\n" ] }, { "cell_type": "markdown", - "id": "d23ebbf5", - "metadata": {}, + "id": "49e3273c-99be-4370-9e59-121c00481b4e", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ "## Distributed with Thunder\n", "\n", @@ -1162,7 +3310,11 @@ "cell_type": "code", "execution_count": 14, "id": "18dd3379", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "outputs": [ { "name": "stdout", @@ -1174,21 +3326,19 @@ ], "source": [ "%%writefile zero_to_thunder_fsdp_simple_example.py\n", - "import sys\n", - "sys.path.insert(0, '..')\n", "from thunder.tests.lit_gpt_model import GPT, Config\n", - "\n", - "import torch\n", - "import torch.distributed\n", - "import thunder\n", - "import thunder.distributed\n", "import os\n", + "import torch, torch.distributed\n", + "import thunder, thunder.distributed\n", "\n", "# Create Model\n", "# NOTE: We create the model on CPU.\n", "device='cpu'\n", "torch.set_default_dtype(torch.bfloat16)\n", - "model = GPT.from_name('llama2-like')\n", + "cfg = Config.from_name('Llama-2-7b-hf')\n", + "cfg.n_layer = 8 # fewer layers\n", + "model = GPT(cfg)\n", + "\n", "# Setup for distributed\n", "torch.distributed.init_process_group(backend='nccl')\n", "rank = int(os.environ[\"LOCAL_RANK\"])\n", @@ -1199,13 +3349,19 @@ "# thunder.distributed.fsdp takes care of moving the parameter\n", "# shard to the correct GPU for the current process.\n", "model = thunder.jit(thunder.distributed.fsdp(model)) # <---------------------------------------\n", - "\n", + "print(f\"rank {rank} computing\")\n", "# Run the forward pass.\n", - "res = model(x)\n", - "res.sum().backward()\n", - "\n", - "res = model(x)\n", - "res.sum().backward()\n" + "for i in range(10):\n", + " res = model(x)\n", + " res.sum().backward()\n" + ] + }, + { + "cell_type": "markdown", + "id": "97e8edbf-424d-49a7-8ed6-12cb5e5d65fc", + "metadata": {}, + "source": [ + "Now we can launch it. Note that you need two GPUs for this to run correctly." ] }, { @@ -1213,17 +3369,22 @@ "execution_count": 15, "id": "2bad9b64", "metadata": { - "scrolled": false + "scrolled": true, + "slideshow": { + "slide_type": "skip" + } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] \r\n", - "W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] *****************************************\r\n", - "W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \r\n", - "W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] *****************************************\r\n" + "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] \n", + "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] *****************************************\n", + "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
\n", + "W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] *****************************************\n", + "rank 1 computing\n", + "rank 0 computing\n" ] } ], @@ -1234,21 +3395,29 @@ { "cell_type": "markdown", "id": "9c65e75d", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, "source": [ - "So there. FSDP with just wrapping the model in `fsdp`." + "So there. FSDP with just wrapping the model in `fsdp`.\n" ] }, { "cell_type": "markdown", "id": "4a6d7a20", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ "## Extending Thunder\n", "\n", "But we promised that thunder is extensible. Let's find out what's up with that.\n", "\n", - "Specifically, we will incorporate the RMSNorm kernel from the great [Unsloth project](https://github.com/unslothai/unsloth/) into our model (note that NVFuser also creates a fused kernel for this).\n", + "Specifically, we will incorporate the fast rope embedding kernel from the great [Unsloth project](https://github.com/unslothai/unsloth/) into our model (note that NVFuser also creates a fused kernel for this).\n", "\n", "In Thunder, extensions (as well as most builtin optimizations which use the exact same mechanism) work with _executors_ handling operations. Let us define one." ] @@ -1277,91 +3446,94 @@ }, { "cell_type": "markdown", - "id": "a63595ab", - "metadata": {}, + "id": "2fe3b40b-c6e9-417c-ab7a-32606cee871a", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, "source": [ - "For our base implementation, we take the code from [LitGPT's RMSNorm implementation](https://github.com/Lightning-AI/litgpt/blob/7c1574925f973e64c0a53e056b77229bedee1619/lit_gpt/rmsnorm.py)\n", + "For our base implementation, we take the code from [LitGPT's implementation](https://github.com/Lightning-AI/litgpt/blob/be6139e1fd4b240d253efd58124457496d23d173/litgpt/model.py#L355-L361)\n", "\n", - "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function.\n" + "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function.\n", + "Because we will demonstrate Thunder's ability to divert functions in the model, we make a version here that will not be diverted." ] }, { "cell_type": "code", "execution_count": 17, - "id": "247074b3", - "metadata": {}, + "id": "3e74436b-d8eb-472b-9d6d-b6412378fde7", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, "outputs": [], "source": [ - "from thunder import TensorProxy\n", - "\n", - "# Taken from LitGPT, who in turn credit:\n", - "# Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. 
BSD 3-Clause License:\n", - "# https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.\n", - "\n", - "def rms_norm_impl(x: torch.Tensor, weight, dim: int, eps: float, add_unit_offset: bool) -> torch.Tensor:\n", - " dtype = x.dtype\n", - " x = x.float()\n", - " # NOTE: the original RMSNorm paper implementation is not equivalent\n", - " norm_x = torch.mean(x * x, dim=dim, keepdim=True)\n", - " x_normed = x * torch.rsqrt(norm_x + eps)\n", - " x_normed = x_normed.to(dtype=dtype)\n", - " if add_unit_offset:\n", - " # Gemma model requires a unit offset\n", - " # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176\n", - " return x_normed * (1 + weight)\n", - " return x_normed * weight\n", - "\n", - "def rms_norm_meta(x: TensorProxy, weight, dim: int, eps: float, add_unit_offset: bool) -> TensorProxy:\n", - " return TensorProxy(like=x)\n", - "\n", - "rms_norm = my_ex.register_operator('rms_norm', meta=rms_norm_meta, fn=rms_norm_impl)\n" + "import lit_gpt\n", + "def apply_rope_copy(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n", + " head_size = x.size(-1)\n", + " x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)\n", + " x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)\n", + " rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)\n", + " roped = (x * cos) + (rotated * sin)\n", + " return roped.to(dtype=x.dtype)" ] }, { "cell_type": "markdown", - "id": "75ad1dbf", - "metadata": {}, + "id": "a63595ab", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, "source": [ - "For this short demo, we monkey-patch LitGPT to replace its own implementation. For your own model, you might start out with a that in your code directly." + "### Registering operators\n", + "\n", + "Say we have a function `apply_rope` applying the RoPE transformation in PyTorch.\n", + "\n", + "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function and tell it to use the new symbol instead of the original function `lit_gpt.model.apply_rope`.\n" ] }, { "cell_type": "code", "execution_count": 18, - "id": "e0bdecd3", + "id": "247074b3", "metadata": {}, "outputs": [], "source": [ - "import lit_gpt.rmsnorm\n", - "if not hasattr(lit_gpt.rmsnorm, 'ThunderOrigRMSNorm'):\n", - " lit_gpt.rmsnorm.ThunderOrigRMSNorm = lit_gpt.rmsnorm.RMSNorm\n", + "import torch, thunder\n", + "from thunder.tests.lit_gpt_model import GPT\n", + "from thunder import TensorProxy\n", "\n", - "class ThunderizedRMSNorm(lit_gpt.rmsnorm.ThunderOrigRMSNorm):\n", - " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", - " # This isn't the best paradigm. 
:/\n", - " if thunder.core.interpreter.is_jitting():\n", - " return rms_norm(x, self.weight, self.dim, self.eps, self.add_unit_offset)\n", - " else:\n", - " return super().forward(x)\n", + "def apply_rope_impl(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n", + " return lit_gpt.model.apply_rope(x, cos, sin)\n", + "\n", + "def apply_rope_meta(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:\n", + " return TensorProxy(like=x)\n", "\n", - "lit_gpt.rmsnorm.RMSNorm = ThunderizedRMSNorm" + "apply_rope = my_ex.register_operator('apply_rope', like=apply_rope_meta, fn=apply_rope_impl,\n", + " replaces=lit_gpt.model.apply_rope)" ] }, { "cell_type": "markdown", "id": "d6b7d056", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ - "We can try our new RMSNorm: " + "### Testing our new operator " ] }, { "cell_type": "code", "execution_count": 19, "id": "0ebd5dd1", - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1379,12 +3551,13 @@ "\n", "@torch.no_grad()\n", "@no_autocast()\n", - "def computation(x, t_weight):\n", - " # x: \"cuda:0 f32[256, 4096]\" \n", - " # t_weight: \"cuda:0 f32[4096]\" \n", - " t7 = rms_norm(x, t_weight, -1, 1e-06, False) # t7: \"cuda:0 f32[256, 4096]\"\n", - " del x, t_weight\n", - " return t7" + "def computation(x, t_1_cos, t_1_sin):\n", + " # x: \"cuda:0 bf16[2, 128, 4096, 16]\" \n", + " # t_1_cos: \"cuda:0 f32[4096, 16]\" \n", + " # t_1_sin: \"cuda:0 f32[4096, 16]\" \n", + " t2 = apply_rope(x, t_1_cos, t_1_sin) # t2: \"cuda:0 bf16[2, 128, 4096, 16]\"\n", + " del x, t_1_cos, t_1_sin\n", + " return t2" ] }, "execution_count": 19, @@ -1393,37 +3566,37 @@ } ], "source": [ - "with torch.device('cuda'):\n", - " norm_module = ThunderizedRMSNorm(4096)\n", - " x = torch.randn(256, 4096)\n", + "with torch.device('cuda'): m = GPT.from_name('llama2-like'); Q = torch.randn(2, 128, 4096, 16)\n", "\n", - "# we're not quite there to handle forward and backward yet, we'll re-enable them below\n", - "for p in norm_module.parameters(): \n", - " p.requires_grad_(False)\n", + "def test_apply_rope(x, m):\n", + " return lit_gpt.model.apply_rope(x, m.cos, m.sin)\n", "\n", - "thunder_norm_module = thunder.jit(norm_module, executors=(my_ex,) + thunder.get_default_executors()) \n", + "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n", "\n", - "expected = norm_module(x)\n", - "actual = thunder_norm_module(x)\n", + "expected = test_apply_rope(Q, m); actual = thunder_apply_rope(Q, m); print(\"deviation:\", (expected - actual).abs().max().item())\n", "\n", - "print(\"deviation:\", (expected - actual).abs().max().item())\n", - "\n", - "thunder.last_traces(thunder_norm_module)[-1]" + "thunder.last_traces(thunder_apply_rope)[-1]" ] }, { "cell_type": "markdown", "id": "8c620a38", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ + "### Optimized kernels\n", + "\n", "But why did we do this? Well, we can now layer a faster implementation on top.\n", - "For this we take the [unsloth RMSNorm](https://github.com/unslothai/unsloth/blob/42076f6580e71522ed1c122043edfba595be64e4/unsloth/kernels/rms_layernorm.py) kernels. We take the bits that were in the forward and backward of the `autograd.Function` into our implementation functions and define the corresponding metas." 
+ "For this we take the [unsloth fast rope embedding](https://github.com/unslothai/unsloth/blob/42076f6580e71522ed1c122043edfba595be64e4/unsloth/kernels/rope_embedding.py) kernels. We take the bits that were in the forward and backward of the `autograd.Function` into our implementation functions. Note that we include the transpositions in our setup in order to have compatibility to the LitGPT implementation. This change in memory layout of the operands can have a large effect on the runtime though, so our timings are likely not representative of the ones the Unsloth project gets in their use of the same triton kernels." ] }, { "cell_type": "code", "execution_count": 20, - "id": "a7a26f5f", + "id": "6e6d0b1e-ba14-43e5-b0d9-27c0e3b46879", "metadata": {}, "outputs": [], "source": [ @@ -1459,196 +3632,214 @@ " elif BLOCK_SIZE >= 2048: num_warps = 8\n", " return BLOCK_SIZE, num_warps\n", "\n", + "@triton.heuristics({\"BACKWARD_PASS\": lambda args: args[\"BACKWARD_PASS\"],})\n", "@triton.jit\n", - "def _rms_layernorm_forward(\n", - " Y, Y_row_stride,\n", - " X, X_row_stride,\n", - " W, W_row_stride,\n", - " r, r_row_stride,\n", - " n_cols, eps,\n", - " BLOCK_SIZE : tl.constexpr\n", + "def _rope_embedding(\n", + " Q, Q_row_stride,\n", + " cos, cos_row_stride,\n", + " sin, sin_row_stride,\n", + " seqlen, head_dim, group_size, n_heads,\n", + " BACKWARD_PASS: tl.constexpr,\n", + " BLOCK_SIZE : tl.constexpr,\n", "):\n", " \"\"\"\n", - " Fast RMS Layernorm kernel\n", - " Inspiration from a Triton tutorial:\n", - " https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n", + " Calculates the RoPE Embedding quickly\n", + " RoPE is Q * cos + rotate_half(Q) * sin\n", + " See our blog post for more info\n", " \"\"\"\n", - " row_idx = tl.program_id(0)\n", - " col_offsets = tl.arange(0, BLOCK_SIZE)\n", - " mask = col_offsets < n_cols\n", + " row_position = tl.program_id(0)\n", + " group_head_position = tl.program_id(1)\n", + " col_offsets = tl.arange(0, BLOCK_SIZE)\n", + " half_head_dim = head_dim // 2\n", + " mask = col_offsets < half_head_dim\n", "\n", - " Y += row_idx * Y_row_stride\n", - " X += row_idx * X_row_stride\n", - " r += row_idx * r_row_stride\n", + " sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \\\n", + " half_head_dim*0 + col_offsets, mask = mask, other = 0)\n", + " cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \\\n", + " half_head_dim*0 + col_offsets, mask = mask, other = 0)\n", "\n", - " X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n", - " W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)\n", + " if BACKWARD_PASS:\n", + " # See our blog post for more info.\n", + " sin1 = -sin1\n", + " pass\n", "\n", - " row_var = tl.sum(X_row * X_row, axis = 0) / n_cols\n", - " inv_var = tl.math.rsqrt(row_var + eps)\n", - " tl.store(r, inv_var)\n", - " normed = X_row * inv_var\n", - " normed = normed.to(W_row.dtype) # Exact copy from HF\n", - " output = normed * W_row\n", - " tl.store(Y + col_offsets, output, mask = mask)\n", + " head_start = group_head_position * group_size\n", + " head_end = min((head_start + group_size), n_heads)\n", "\n", + " for i in range(head_start, head_end):\n", + " offs_q1 = row_position * Q_row_stride + i * head_dim + col_offsets\n", + " offs_q2 = row_position * Q_row_stride + i * head_dim + col_offsets + half_head_dim\n", "\n", - "@triton.jit\n", - "def _rms_layernorm_backward(\n", - " dY, dY_row_stride,\n", - " X, X_row_stride,\n", - " W, W_row_stride,\n", - " r, 
r_row_stride,\n", - " dW, dW_row_stride,\n", - " n_cols, eps,\n", - " BLOCK_SIZE : tl.constexpr,\n", - "):\n", - " \"\"\"\n", - " Fast RMS Layernorm kernel for the backward pass\n", - " Inspiration from a Triton tutorial:\n", - " https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n", - " \"\"\"\n", - " row_idx = tl.program_id(0)\n", - " col_offsets = tl.arange(0, BLOCK_SIZE)\n", - " mask = col_offsets < n_cols\n", - "\n", - " dY += row_idx * dY_row_stride\n", - " X += row_idx * X_row_stride\n", - " r += row_idx * r_row_stride\n", - "\n", - " dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)\n", - " X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n", - " W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)\n", - "\n", - " # Get saved row variance\n", - " inv_var = tl.load(r).to(tl.float32)\n", - " normed = X_row * inv_var\n", - "\n", - " dY_W = dY_row * W_row\n", - "\n", - " rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)\n", - " output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)\n", - " tl.store(dY + col_offsets, output, mask = mask)\n", - " \n", - "def rms_layernorm_forward_impl(X, W, eps):\n", - " shape = X.shape\n", - " dim = shape[-1]\n", - " X = X.view(-1, dim)\n", - " n_rows, n_cols = X.shape\n", - " BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n", - "\n", - " Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = \"cuda\")\n", - " r = torch.empty(n_rows, dtype = torch.float32, device = \"cuda\")\n", - "\n", - " _rms_layernorm_forward[(n_rows,)](\n", - " Y, Y.stride(0),\n", - " X, X.stride(0),\n", - " W, W.stride(0),\n", - " r, r.stride(0),\n", - " n_cols, eps,\n", + " # For Gemma - sometimes RoPE must be done in float32 and not bfloat16\n", + " Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)\n", + " Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)\n", + "\n", + " tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)\n", + " tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)\n", + " pass\n", + "pass\n", + "\n", + "\n", + "def fast_rope_embedding_forward(Q, cos, sin):\n", + " Q = Q.transpose(1, 2).clone()\n", + " cos, sin = cos.squeeze(), sin.squeeze()\n", + " batch, seq_len, n_heads, head_dim = Q.shape\n", + " Q = Q.reshape(batch*seq_len, n_heads*head_dim)\n", + " n_rows, n_cols = Q.shape\n", + " assert(seq_len <= cos.shape[0])\n", + "\n", + " # [TODO] Changing blocksize to head_dim//2 seems to have\n", + " # some concurrency / un-deterministic issues.\n", + " BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)\n", + " group_size = 4 # 4 or 8, too large group_size can hurt performance.\n", + " n_groups = triton.cdiv(n_heads, group_size)\n", + "\n", + " grid = (n_rows, n_groups, )\n", + " _rope_embedding[grid](\n", + " Q, Q.stride(0),\n", + " cos, cos.stride(0),\n", + " sin, sin.stride(0),\n", + " seq_len, head_dim, group_size, n_heads,\n", + " BACKWARD_PASS = False,\n", " BLOCK_SIZE = BLOCK_SIZE,\n", " num_warps = num_warps,\n", " )\n", - " return Y.view(*shape), (r, BLOCK_SIZE, num_warps)\n", - "\n", - "def rms_layernorm_forward_meta(X, W, eps):\n", - " n_cols = X.shape[-1]\n", - " n_rows = 1\n", - " for i in X.shape[:-1]:\n", - " n_rows *= i\n", - " BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n", - " Y = TensorProxy(like=X, requires_grad=True)\n", - " return (Y,\n", - " (TensorProxy(shape=(n_rows,), device=X.device, dtype=thunder.dtypes.float32, requires_grad=False),\n", - " BLOCK_SIZE, \n", - " 
num_warps,\n", - " )\n", - " )\n", - "\n", - "def rms_layernorm_backward_impl(X, W, r, eps, BLOCK_SIZE, num_warps, dY):\n", - " shape = dY.shape\n", - " dim = shape[-1]\n", - " dY = dY.view(-1, dim)\n", + " Q = Q.view(batch, seq_len, n_heads, head_dim).transpose(1, 2)\n", + " return Q, (BLOCK_SIZE, num_warps) \n", + "\n", + "def fast_rope_embedding_backward(BLOCK_SIZE, num_warps, cos, sin, dY):\n", + " dY = dY.transpose(1, 2)\n", + " batch, seq_len, n_heads, head_dim = dY.shape\n", + " dY = dY.reshape(batch*seq_len, n_heads*head_dim)\n", + " # Must be reshape not view\n", " n_rows, n_cols = dY.shape\n", - " dW = X\n", - " dX = dY.clone()\n", - " _rms_layernorm_backward[(n_rows,)](\n", - " dX, dX.stride(0),\n", - " X, X .stride(0),\n", - " W, W .stride(0),\n", - " r, r .stride(0),\n", - " dW, dW.stride(0),\n", - " n_cols, eps,\n", + "\n", + " group_size = 4 # 4 or 8, too large group_size can hurt performance.\n", + " n_groups = triton.cdiv(n_heads, group_size)\n", + "\n", + " grid = (n_rows, n_groups, )\n", + " _rope_embedding[grid](\n", + " dY, dY .stride(0),\n", + " cos, cos.stride(0),\n", + " sin, sin.stride(0),\n", + " seq_len, head_dim, group_size, n_heads,\n", + " BACKWARD_PASS = True,\n", " BLOCK_SIZE = BLOCK_SIZE,\n", " num_warps = num_warps,\n", " )\n", - " dX = dX.view(*shape)\n", - " return dX\n", + " dY = dY.view(batch, seq_len, n_heads, head_dim)\n", + " dY = dY.transpose(1, 2) \n", + " return dY\n" + ] + }, + { + "cell_type": "markdown", + "id": "ed1e9be3-d1c9-4c4b-bf14-a025a03687ac", + "metadata": {}, + "source": [ + "We also define the corresponding meta functions." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "d7e6612d-f1fc-497c-9d64-15ef99824086", + "metadata": {}, + "outputs": [], + "source": [ + "def fast_rope_embedding_forward_meta(Q, cos, sin):\n", + " batch, n_heads, seq_len, head_dim = Q.shape\n", + " n_rows, n_cols = batch*seq_len, n_heads*head_dim \n", + " assert(seq_len <= cos.shape[0])\n", + "\n", + " BLOCK_SIZE, num_warps = calculate_settings(head_dim//2)\n", + " return TensorProxy(like=Q), (BLOCK_SIZE, num_warps) \n", "\n", - "def rms_layernorm_backward_meta(X, W, r, eps, BLOCK_SIZE, num_warps, dY):\n", + "def fast_rope_embedding_backward_meta(BLOCK_SIZE, num_warps, cos, sin, dY):\n", " return TensorProxy(like=dY)" ] }, { "cell_type": "markdown", "id": "b70eba5f", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ - "With this, we can just register the additional operators:" + "### Register optimized operators\n", + "\n", + "Just like the `apply_rope` before, we can register operators for the optimized forward and backward." 
] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 22, "id": "f8f1e77e", "metadata": {}, "outputs": [], "source": [ - "unsloth_rms_norm_forward = my_ex.register_operator('unsloth_rms_norm_forward', meta=rms_layernorm_forward_meta, fn=rms_layernorm_forward_impl)\n", - "unsloth_rms_norm_backward = my_ex.register_operator('unsloth_rms_norm_backward', meta=rms_layernorm_backward_meta, fn=rms_layernorm_backward_impl)" + "unsloth_apply_rope_forward = my_ex.register_operator('unsloth_apply_rope_forward', \n", + " meta=fast_rope_embedding_forward_meta, fn=fast_rope_embedding_forward)\n", + "unsloth_apply_rope_backward = my_ex.register_operator('unsloth_apply_rope_backward', \n", + " meta=fast_rope_embedding_backward_meta, fn=fast_rope_embedding_backward)" ] }, { "cell_type": "markdown", "id": "2426263d", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ - "But instead of monkey-patching more, we can now register the kernel as an _implementation_ of the base `rms_norm` primitive defined above. For this we need an _execution transform_ - which is a fancy word for a function that implements the original operator (`rms_norm`) in terms of our new operator - so it has the call signature of the `rms_norm`. Because - like many fast implementations - the unsloth RMS norm does not implement the operator in full generality (to do them justice, they have a variant adding the unit offset, we just didn't copy it over), we implement a checker function, too: It takes the arguments of the operator we want specialize and returns a bool whether our implementation handles the given inputs." + "### Implementations for operators\n", + "\n", + "Do we need to divert `apply_rope` again? No!\n", + "We can register the specialized kernel as an _implementation_ of our base `apply_rope` operator. For this we need an _execution transform_ - which is a fancy word for a function that implements the original operator (`apply_rope`) in terms of our new operator - so it has the call signature of `apply_rope`. Because - like many fast implementations - the unsloth rope embedding does not implement the operator in full generality (well, actually they mainly want a 4d tensor input), we implement a checker function, too: It takes the arguments of the operator we want to specialize and returns a bool indicating whether our implementation handles the given inputs." ] }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 23, "id": "6b5c8320", "metadata": {}, "outputs": [], "source": [ - "def rms_norm_to_unsloth(x: TensorProxy, weight: TensorProxy, dim: int, eps: float, add_unit_offset: bool):\n", - " assert dim == -1 and not add_unit_offset\n", - " res, _ = unsloth_rms_norm_forward(x, weight, eps)\n", + "def apply_rope_to_unsloth(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:\n", + " assert len(x.shape) == 4\n", + " res, *_ = unsloth_apply_rope_forward(x, cos, sin)\n", " return res\n", "\n", - "def rms_norm_to_unsloth_checker(x: TensorProxy, weight: TensorProxy, dim: int, eps: float, add_unit_offset: bool):\n", - " if dim != -1 or add_unit_offset:\n", + "def apply_rope_to_unsloth_checker(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> bool:\n", + " if len(x.shape) != 4:\n", " return False\n", - " if weight.requires_grad:\n", - " return False # the unsloth rms norm backwward only gives the grad w.r.t. 
x\n", - " return x.device.devicetype == thunder.devices.DeviceType.CUDA and weight.device.devicetype == thunder.devices.DeviceType.CUDA\n", + " return (x.device.devicetype == thunder.devices.DeviceType.CUDA and\n", + " cos.device.devicetype == thunder.devices.DeviceType.CUDA and\n", + " cos.device.devicetype == thunder.devices.DeviceType.CUDA)\n", "\n", - "my_ex.register_implementation(rms_norm, checker=rms_norm_to_unsloth_checker, execution_transform=rms_norm_to_unsloth)\n" + "my_ex.register_implementation(apply_rope,\n", + " checker=apply_rope_to_unsloth_checker,\n", + " execution_transform=apply_rope_to_unsloth)\n" ] }, { "cell_type": "markdown", "id": "eec7c95a", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ - "So let us give that a try! Works great..." + "So let us give it a try! Works great..." ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 24, "id": "965ba1d7", "metadata": {}, "outputs": [ @@ -1656,7 +3847,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "deviation: 9.5367431640625e-07\n" + "deviation: 0.015625\n" ] }, { @@ -1668,49 +3859,45 @@ "\n", "@torch.no_grad()\n", "@no_autocast()\n", - "def computation(x, t_weight):\n", - " # x: \"cuda:0 f32[2048, 4096]\" \n", - " # t_weight: \"cuda:0 f32[4096]\" \n", - " (t7, (_, _, _)) = unsloth_rms_norm_forward(x, t_weight, 1e-06)\n", - " del x, t_weight\n", - " return t7" + "def computation(x, t_1_cos, t_1_sin):\n", + " # x: \"cuda:0 bf16[2, 128, 4096, 16]\" \n", + " # t_1_cos: \"cuda:0 f32[4096, 16]\" \n", + " # t_1_sin: \"cuda:0 f32[4096, 16]\" \n", + " (t2, (_, _)) = unsloth_apply_rope_forward(x, t_1_cos, t_1_sin)\n", + " del x, t_1_cos, t_1_sin\n", + " return t2" ] }, - "execution_count": 23, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "with torch.device('cuda'):\n", - " norm_module = ThunderizedRMSNorm(4096)\n", - "\n", - "# unfortunately, we meet dragons if we don't do this at this stage\n", - "for p in norm_module.parameters(): \n", - " p.requires_grad_(False)\n", - "\n", - "thunder_norm_module = thunder.jit(norm_module, executors=[my_ex,]) \n", - "x = torch.randn(2048, 4096, device=\"cuda\")\n", - "\n", - "expected = norm_module(x)\n", - "actual = thunder_norm_module(x)\n", + "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n", "\n", + "expected = test_apply_rope(Q, m)\n", + "actual = thunder_apply_rope(Q, m)\n", "print(\"deviation:\", (expected - actual).abs().max().item())\n", "\n", - "thunder.last_traces(thunder_norm_module)[-1]" + "thunder.last_traces(thunder_apply_rope)[-1]" ] }, { "cell_type": "markdown", - "id": "0e3e4d85", - "metadata": {}, + "id": "69a93d3d-3a88-4297-b330-23a7fff2c4b4", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ "And this is also automatic when we instantiate a larger llama2-like model:" ] }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 25, "id": "7fff2522", "metadata": {}, "outputs": [ @@ -1718,7 +3905,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "deviation: 4.76837158203125e-07\n" + "deviation: 5.960464477539062e-07\n" ] } ], @@ -1742,34 +3929,37 @@ { "cell_type": "markdown", "id": "b538cb40", - "metadata": {}, + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, "source": [ - "By peeking into the trace, we can see that it actually used the unsloth RMS kernels:" + "By peeking into the trace, we can see that it actually used the 
unsloth apply rope:" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 26, "id": "c260cb25", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[' (n_1, (_, _, _)) = unsloth_rms_norm_forward(x, t_transformer_h_0_norm_1_weight, 1e-05)',\n", - " ' (t110, (_, _, _)) = unsloth_rms_norm_forward(t102, t_transformer_h_0_norm_2_weight, 1e-05)',\n", - " ' (t139, (_, _, _)) = unsloth_rms_norm_forward(t130, t_transformer_h_1_norm_1_weight, 1e-05)',\n", - " ' (t215, (_, _, _)) = unsloth_rms_norm_forward(t207, t_transformer_h_1_norm_2_weight, 1e-05)',\n", - " ' (t243, (_, _, _)) = unsloth_rms_norm_forward(t235, t_transformer_ln_f_weight, 1e-05)']" + "[' (q_roped, (_, _)) = unsloth_apply_rope_forward(t55, cos, sin)',\n", + " ' (k_roped, (_, _)) = unsloth_apply_rope_forward(t57, cos, sin)',\n", + " ' (t165, (_, _)) = unsloth_apply_rope_forward(t164, cos, sin)',\n", + " ' (t167, (_, _)) = unsloth_apply_rope_forward(t166, cos, sin)']" ] }, - "execution_count": 25, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "[s for s in str(thunder.last_traces(thunder_model)[-1]).split('\\n') if 'rms' in s]" + "[s for s in str(thunder.last_traces(thunder_model)[-1]).split('\\n') if 'apply_rope' in s]" ] }, { @@ -1777,79 +3967,97 @@ "id": "0f6c0780", "metadata": {}, "source": [ - "But what about the backward?\n", + "### But what about the backward?\n", "\n", - "Well, we have to connect forward and backward with a grad transformation. With our specialized ops, this is very simple, we compute the forward, call `get_grad` for the output, compute the backward, and put it on the input with `put_grads`." + "Well, we have to connect forward and backward with a grad transformation. With our specialized ops, this is very simple, we compute the forward, call `get_grad` for the output, compute the backward, and put it on the input with `put_grads`. 
\n" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 27, "id": "7670a872", "metadata": {}, "outputs": [], "source": [ "from thunder.core.transforms import get_grad, put_grads\n", "\n", - "def unsloth_rms_norm_grad(x: TensorProxy, weight, dim: int, eps: float, add_unit_offset: bool):\n", - " res, (r, BLOCK_SIZE, num_warps) = unsloth_rms_norm_forward(x, weight, eps)\n", + "def unsloth_apply_rope_grad(x: TensorProxy, cos: TensorProxy, sin: TensorProxy):\n", + " res, (BLOCK_SIZE, num_warps) = unsloth_apply_rope_forward(x, cos, sin)\n", " grad_res = get_grad(res)\n", - " grad_x = unsloth_rms_norm_backward(x, weight, r, eps, BLOCK_SIZE, num_warps, grad_res)\n", + " grad_x = unsloth_apply_rope_backward(BLOCK_SIZE, num_warps, cos, sin, grad_res)\n", " put_grads((x,), (grad_x,))\n", " return res\n", "\n", - "my_ex.register_implementation(rms_norm, checker=rms_norm_to_unsloth_checker,\n", - " execution_transform=rms_norm_to_unsloth,\n", - " grad_transform=unsloth_rms_norm_grad \n", + "my_ex.register_implementation(apply_rope, checker=apply_rope_to_unsloth_checker,\n", + " execution_transform=apply_rope_to_unsloth,\n", + " grad_transform=unsloth_apply_rope_grad \n", " )\n", "\n" ] }, { - "cell_type": "code", - "execution_count": 27, - "id": "d31aced0", + "cell_type": "markdown", + "id": "219dfaa4-cdef-47de-b60c-7c7c1642cb84", "metadata": {}, + "source": [ + "Note that the parts are not actually executed at the same time in the actual computation, but just during tracing.\n" + ] + }, + { + "cell_type": "markdown", + "id": "68226a4a-6ad8-43fb-b92f-c1e8eec6f13e", + "metadata": {}, + "source": [ + "And let us try our function using the optimized backward" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "ccc3ed63-ddc2-4b0e-bcd0-f77d66fefe9f", + "metadata": { + "scrolled": true + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "torch.Size([256, 4096]) torch.Size([256, 4096]) torch.Size([4096]) torch.Size([256]) torch.Size([256, 4096])\n", - "(4096, 1) (4096, 1) (1,) (1,) (4096, 1)\n", - "maximum deviation grads: 3.5762786865234375e-07\n" + "res deviation: 0.015625\n", + "grad deviation: 0.0078125\n" ] } ], "source": [ - "with torch.device('cuda'):\n", - " norm_module = ThunderizedRMSNorm(4096)\n", - " norm_module.weight.requires_grad_(False)\n", - " x = torch.randn(256, 4096, requires_grad=True)\n", + "Q.requires_grad_()\n", "\n", - "thunder_norm_module = thunder.jit(norm_module, executors=(my_ex,) + thunder.get_default_executors()) \n", + "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())\n", "\n", - "actual = thunder_norm_module(x)\n", - "expected = norm_module(x)\n", - "actual_grads = torch.autograd.grad(actual.sum(), x)\n", - "expected_grads = torch.autograd.grad(expected.sum(), x)\n", + "expected = test_apply_rope(Q, m)\n", + "go = torch.ones_like(expected)\n", + "gr_expected, = torch.autograd.grad(expected, Q, go)\n", + "actual = thunder_apply_rope(Q, m)\n", + "gr_actual, = torch.autograd.grad(actual, Q, go)\n", "\n", - "print(\"maximum deviation grads:\", max((a-e).abs().max().item() for a, e in zip(actual_grads, expected_grads)))" + "print(\"res deviation:\", (expected - actual).abs().max().item())\n", + "print(\"grad deviation:\", (gr_expected - gr_actual).abs().max().item())" ] }, { "cell_type": "markdown", - "id": "be218e9d", + "id": "63cb61ee-c791-49d1-ba5c-3fe4b5b9a9d5", "metadata": {}, "source": [ - "And here is our module having the unsloth backward:" + "And with 
`last_backward_traces` we can check that our module is using the unsloth backward:" ] }, { "cell_type": "code", "execution_count": 29, - "id": "ac00153b", - "metadata": {}, + "id": "cd12ca02-6f06-4d88-b5b7-25c4c27dbc9a", + "metadata": { + "scrolled": true + }, "outputs": [ { "data": { @@ -1864,7 +4072,7 @@ " # saved_for_backward: \"Collection\" \n", " # cotangents: \"Collection\" \n", " C0, \\\n", - " C1, \\\n", + " _, \\\n", " = saved_for_backward\n", " clear_collection(saved_for_backward)\n", " del saved_for_backward\n", @@ -1872,19 +4080,14 @@ " = cotangents\n", " clear_collection(cotangents)\n", " del cotangents\n", - " t0, \\\n", " t1, \\\n", - " t3, \\\n", + " t2, \\\n", " = C0\n", " clear_collection(C0)\n", " del C0\n", - " f0, \\\n", - " = C1\n", - " clear_collection(C1)\n", - " del C1\n", - " t2 = unsloth_rms_norm_backward(t0, t1, t3, f0, 4096, 8, t4) # t2: \"cuda:0 f32[256, 4096]\"\n", - " del t0, t1, t3, f0, t4\n", - " return (t2, None)" + " t3 = unsloth_apply_rope_backward(8, 4, t1, t2, t4) # t3: \"cuda:0 bf16[2, 128, 4096, 16]\"\n", + " del t1, t2, t4\n", + " return (t3, None, None)" ] }, "execution_count": 29, @@ -1893,29 +4096,96 @@ } ], "source": [ - "thunder.last_backward_traces(thunder_norm_module)[-1]" + "thunder.last_backward_traces(thunder_apply_rope)[-1]" ] }, { "cell_type": "markdown", - "id": "26ac79f0", + "id": "2776d183-0232-495e-aa75-3b90e799c841", "metadata": {}, "source": [ - "That's it! Do check out our LitGPT studios and the other tutorial notebooks.\n" + "### Comparing and exploring optimizations\n", + "\n", + "It is also straightforward to compare potential optimizations.\n", + "\n", + "Note again, that our use of the unsloth kernel might not result in the same performance as the unsloth project sees due to differences in the hardware used, software environment, or memory layout of the operands." ] }, { "cell_type": "code", - "execution_count": null, - "id": "586cdd30", + "execution_count": 30, + "id": "a5e0ce05", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "eager\n", + "3.84 ms ± 3.46 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", + "thunder + unsloth\n", + "6.69 ms ± 3.45 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", + "thunder default (nvfuser)\n", + "1.4 ms ± 4.98 µs per loop (mean ± std. dev. 
of 7 runs, 1,000 loops each)\n" + ] + } + ], + "source": [ + "def test_apply_rope_copy(x, m):\n", + " return apply_rope_copy(x, m.cos, m.sin)\n", + "\n", + "test_apply_rope_myex = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n", + "test_apply_rope_nvfuser = thunder.jit(test_apply_rope_copy)\n", + "y = test_apply_rope(Q, m); gr, = torch.autograd.grad(y, Q, go)\n", + "y = test_apply_rope_myex(Q, m); gr, = torch.autograd.grad(y, Q, go)\n", + "y = test_apply_rope_nvfuser(Q, m); gr, = torch.autograd.grad(y, Q, go)\n", + "\n", + "print(\"eager\")\n", + "%timeit y = test_apply_rope(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()\n", + "print(\"thunder + unsloth\")\n", + "%timeit y = test_apply_rope_myex(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()\n", + "print(\"thunder default (nvfuser)\")\n", + "%timeit y = test_apply_rope_nvfuser(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()\n" + ] + }, + { + "cell_type": "markdown", + "id": "08b8454f-c725-470c-92a5-56b2206af0e8", "metadata": {}, - "outputs": [], - "source": [] + "source": [ + "That's it!\n", + "\n", + "## Conclusion\n", + "\n", + "To wrap up, we hope you got a taste of\n", + "\n", + "- Getting things going with Thunder:\n", + "\n", + " - Applying Thunder through `thunder.jit` and\n", + " - using FSDP by just wrapping the model in `thunder.distributed.fsdp` before compilation.\n", + "\n", + "- See what's going on inspecting traces:\n", + "\n", + " - `thunder.last_traces` for the forward traces,\n", + " - `thunder.last_backward_traces` for the backward,\n", + " \n", + "- Extending Thunder:\n", + "\n", + " - registering operators with the `OperatorExecutor`,\n", + " - defining implementations with custom forward and backward to include optimized kernels.\n", + "\n", + "Keep in mind that Thunder is still experimental and only expected to work with the limited set of models we have tested it with. You will find bugs and missing pieces. Naturally, we would love for you to help us fix these! You can find us on the [Thunder section of the Lightning forums](https://lightning.ai/forums/c/thunder) or in the `#thunder` channel on the [PyTorch-Lightning slack](https://pytorch-lightning.slack.com/). \n", + "\n", + "Do check out our LitGPT studios and the other tutorial notebooks.\n" + ] } ], "metadata": { + "celltoolbar": "Slideshow", "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -1929,7 +4199,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.10" } }, "nbformat": 4,