Skip to content

Commit c2b5dde

Browse files
committed
feat: deallocate scratch memory with gpu.dealloc after its last use
1 parent 7fe06e1 commit c2b5dde

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ struct JITCallScratchMemoryLowering
6868
auto funcOpInterface = dyn_cast<FunctionOpInterface>(funcOp);
6969

7070
auto &fnBody = funcOp->getRegion(0).front();
71-
rewriter.setInsertionPoint(&fnBody, fnBody.begin());
7271

7372
for (unsigned idx : rewriteScratchMemoryIdxs.set_bits()) {
73+
rewriter.setInsertionPoint(&fnBody, fnBody.begin());
7474
auto scratchMemoryOp =
7575
inputs[idx].getDefiningOp<triton_ext::ScratchMemoryOp>();
7676
auto outTy =
@@ -93,7 +93,22 @@ struct JITCallScratchMemoryLowering
9393
allocOp.getResult());
9494
rewriter.replaceAllUsesWith(fnBody.getArgument(idx), ptrOp.getResult());
9595

96-
// TODO: dealloc the ops using gpu.dealloc
96+
SmallVector<Value> deps;
97+
Operation *lastUser = ptrOp;
98+
for (auto u : ptrOp->getUsers()) {
99+
if (auto gpuLaunchOp = dyn_cast<gpu::LaunchFuncOp>(u)) {
100+
deps.push_back(gpuLaunchOp.getAsyncToken());
101+
}
102+
103+
if (lastUser->isBeforeInBlock(u)) {
104+
lastUser = u;
105+
}
106+
}
107+
108+
rewriter.setInsertionPointAfter(lastUser);
109+
gpu::DeallocOp::create(rewriter, op.getLoc(),
110+
gpu::AsyncTokenType::get(rewriter.getContext()),
111+
ValueRange(deps), allocOp.getResult());
97112
}
98113

99114
funcOpInterface.eraseArguments(rewriteScratchMemoryIdxs);

0 commit comments

Comments
 (0)