diff --git a/notebooks/liger_kernel.ipynb b/notebooks/liger_kernel.ipynb index 20265b116f..8f1764c6d4 100644 --- a/notebooks/liger_kernel.ipynb +++ b/notebooks/liger_kernel.ipynb @@ -367,10 +367,10 @@ "\n", "jm = thunder.jit(m, executors=(liger_ex,), transforms=(MergeRopeTransform(),))\n", "res = jm(inp, inp_pos)\n", - "ref = m(inp, inp_pos)\n", "\n", "go = torch.randn_like(res)\n", "(grad_res,) = torch.autograd.grad(res, jm.get_parameter(\"transformer.wte.weight\"), go)\n", + "ref = m(inp, inp_pos)\n", "(grad_ref,) = torch.autograd.grad(ref, m.get_parameter(\"transformer.wte.weight\"), go)\n", "\n", "assert_close(res, ref)\n",