@@ -68,9 +68,9 @@ struct JITCallScratchMemoryLowering
6868 auto funcOpInterface = dyn_cast<FunctionOpInterface>(funcOp);
6969
7070 auto &fnBody = funcOp->getRegion (0 ).front ();
71- rewriter.setInsertionPoint (&fnBody, fnBody.begin ());
7271
7372 for (unsigned idx : rewriteScratchMemoryIdxs.set_bits ()) {
73+ rewriter.setInsertionPoint (&fnBody, fnBody.begin ());
7474 auto scratchMemoryOp =
7575 inputs[idx].getDefiningOp <triton_ext::ScratchMemoryOp>();
7676 auto outTy =
@@ -93,7 +93,22 @@ struct JITCallScratchMemoryLowering
9393 allocOp.getResult ());
9494 rewriter.replaceAllUsesWith (fnBody.getArgument (idx), ptrOp.getResult ());
9595
96- // TODO: dealloc the ops using gpu.dealloc
96+ SmallVector<Value> deps;
97+ Operation *lastUser = ptrOp;
98+ for (auto u : ptrOp->getUsers ()) {
99+ if (auto gpuLaunchOp = dyn_cast<gpu::LaunchFuncOp>(u)) {
100+ deps.push_back (gpuLaunchOp.getAsyncToken ());
101+ }
102+
103+ if (lastUser->isBeforeInBlock (u)) {
104+ lastUser = u;
105+ }
106+ }
107+
108+ rewriter.setInsertionPointAfter (lastUser);
109+ gpu::DeallocOp::create (rewriter, op.getLoc (),
110+ gpu::AsyncTokenType::get (rewriter.getContext ()),
111+ ValueRange (deps), allocOp.getResult ());
97112 }
98113
99114 funcOpInterface.eraseArguments (rewriteScratchMemoryIdxs);
0 commit comments