Commit 2caf5df

Improve memory usage for stablecode-completion-alpha-3b (#1019)
1 parent 8d19015 commit 2caf5df

1 file changed: +4 -0 lines changed

thunder/executors/torch_compile.py

Lines changed: 4 additions & 0 deletions
@@ -229,6 +229,10 @@ def cuda_device_checker(*args, **kwargs):
     prims.reshape.id,
     prims.slice_prim.id,
     prims.transpose.id,
+    # div and erf are used in GELU and are fused horizontally with RoPE when
+    # parallel residual paths are used in the transformer block
+    prims.div.id,
+    prims.erf.id,
 }
 torch_compile_cat_ex._implmap = {
     op: ImplInfo(checker=cuda_device_checker) for op in pytorch_ex.implmap if op in supported_ops
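
For context on the comment in the diff, here is a minimal sketch of the erf-based GELU formula, written in plain Python to show where a division and an erf call appear. The function name `gelu_erf` is hypothetical, and whether Thunder's trace decomposes GELU into exactly these steps is an assumption; the commit only states that div and erf are used in GELU.

```python
import math


def gelu_erf(x: float) -> float:
    """Erf-based GELU: 0.5 * x * (1 + erf(x / sqrt(2))).

    Illustrative scalar version only; not Thunder's implementation.
    """
    scaled = x / math.sqrt(2.0)      # division step (the role of a div op)
    gate = 1.0 + math.erf(scaled)    # error-function step (the role of an erf op)
    return 0.5 * x * gate


if __name__ == "__main__":
    for value in (-1.0, 0.0, 1.0):
        print(value, gelu_erf(value))
```

If the executor supported neither operation, these two steps would presumably fall outside the fused region, which is consistent with the commit adding both `prims.div.id` and `prims.erf.id` together.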
