Skip to content

Commit

Permalink
use from_bsym_swap_proxies
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitij12345 committed Aug 16, 2024
1 parent 86b0c05 commit 091aa9b
Showing 1 changed file with 11 additions and 21 deletions.
32 changes: 11 additions & 21 deletions notebooks/writing_a_trace_transform_cpu_offloading.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -250,17 +250,7 @@
"\n",
" def _replace_saved_tensors(self, forward_trace, new_output_map):\n",
" return_bsym = forward_trace.bound_symbols[-1]\n",
" return_bsym_args = return_bsym.args\n",
" saved_tensors = return_bsym.args[1][0]\n",
"\n",
" new_saved_tensors = []\n",
" for t in saved_tensors:\n",
" new_output = new_output_map.get(variableify(t), t)\n",
" new_saved_tensors.append(new_output)\n",
"\n",
" new_return_bsym = BoundSymbol.from_bsym(\n",
" return_bsym, **{\"args\": (return_bsym_args[0], (tuple(new_saved_tensors), return_bsym_args[1][1]))}\n",
" )\n",
" new_return_bsym = return_bsym.from_bsym_swap_proxies(new_output_map)\n",
"\n",
" # Replace the old return with our new return.\n",
" forward_trace.bound_symbols.pop(-1)\n",
Expand Down Expand Up @@ -354,7 +344,7 @@
" new_out.append(offloaded_tensors[vout])\n",
" else:\n",
" new_out.append(out)\n",
" new_unpack_bsym = BoundSymbol.from_bsym(unpack_sym, output=(new_out,))\n",
" new_unpack_bsym = BoundSymbol.from_bsym(unpack_sym, output=tuple(new_out))\n",
" computation_trace.bound_symbols[unpack_idx] = new_unpack_bsym\n",
"\n",
" # Now we again find the first usages of offloaded tensor\n",
Expand Down Expand Up @@ -417,7 +407,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -431,7 +421,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -459,16 +449,16 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Peak Memory with thunder : 2232320 bytes\n",
"Peak Memory with CPU Offloading : 1396736 bytes\n",
"Allocated Memory after cleaning 8.192e-06 GB\n"
"Peak Memory with thunder : 2418688 bytes\n",
"Peak Memory with CPU Offloading : 1583104 bytes\n",
"Allocated Memory after cleaning 0.000186368 GB\n"
]
}
],
Expand Down Expand Up @@ -629,14 +619,14 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Allocated Memory after cleaning 8.192e-06 GB\n"
"Allocated Memory after cleaning 0.000186368 GB\n"
]
}
],
Expand Down

0 comments on commit 091aa9b

Please sign in to comment.