From 24f03258610cf4903a2e50f01e71b8a5ea5b4e2f Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Tue, 30 Dec 2025 16:39:47 +0530 Subject: [PATCH 1/4] Add ROCm/HIP support for TurboDiffusion ops - Add HIP kernels for GEMM, LayerNorm, RMSNorm, and quantization ops - Integrate rocWMMA for matrix operations on AMD GPUs - Update setup.py for Windows ROCm builds with clang-cl - Add platform detection (CUDA/HIP) with common abstractions - Optimize SLA kernel config for ROCm (BLKK=16) - Update .gitignore to exclude build artifacts and IDE files - Fix distributed utils and network files for ROCm compatibility --- .gitignore | 20 +- build_ext_log.txt | Bin 0 -> 235342 bytes pyproject.toml | 2 +- rocwmma_lib | 1 + setup.py | 155 ++++-- turbodiffusion/SLA/core.py | 42 +- .../imaginaire/utils/distributed.py | 2 +- turbodiffusion/imaginaire/utils/misc.py | 6 +- turbodiffusion/ops/bindings.cpp | 11 +- turbodiffusion/ops/common/common_hip.hpp | 108 ++++ turbodiffusion/ops/common/launch_hip.hpp | 45 ++ turbodiffusion/ops/common/load.hpp | 57 +- turbodiffusion/ops/common/platform.hpp | 208 +++++++ turbodiffusion/ops/common/store_hip.hpp | 78 +++ turbodiffusion/ops/gemm/gemm_rocwmma.hip | 75 +++ turbodiffusion/ops/gemm/kernel_hip.hpp | 523 ++++++++++++++++++ turbodiffusion/ops/gemm/kernel_rocwmma.hpp | 332 +++++++++++ turbodiffusion/ops/gemm/launch_hip.hpp | 66 +++ turbodiffusion/ops/gemm/launch_rocwmma.hpp | 80 +++ turbodiffusion/ops/gemm/utils_hip.hpp | 130 +++++ turbodiffusion/ops/norm/layernorm.hip | 70 +++ turbodiffusion/ops/norm/layernorm_hip.hpp | 221 ++++++++ turbodiffusion/ops/norm/norm_rocm.hpp | 428 ++++++++++++++ turbodiffusion/ops/norm/rmsnorm.hip | 64 +++ turbodiffusion/ops/norm/rmsnorm_hip.hpp | 166 ++++++ turbodiffusion/ops/quant/quant.hip | 77 +++ turbodiffusion/ops/quant/quant_hip.hpp | 194 +++++++ turbodiffusion/ops/quant/quant_rocm.hpp | 261 +++++++++ turbodiffusion/rcm/networks/wan2pt1.py | 11 +- turbodiffusion/rcm/networks/wan2pt1_jvp.py | 6 +- turbodiffusion/rcm/networks/wan2pt2.py | 9 +- turbodiffusion/rcm/utils/context_parallel.py | 7 +- 32 files changed, 3368 insertions(+), 87 deletions(-) create mode 100644 build_ext_log.txt create mode 160000 rocwmma_lib create mode 100644 turbodiffusion/ops/common/common_hip.hpp create mode 100644 turbodiffusion/ops/common/launch_hip.hpp create mode 100644 turbodiffusion/ops/common/platform.hpp create mode 100644 turbodiffusion/ops/common/store_hip.hpp create mode 100644 turbodiffusion/ops/gemm/gemm_rocwmma.hip create mode 100644 turbodiffusion/ops/gemm/kernel_hip.hpp create mode 100644 turbodiffusion/ops/gemm/kernel_rocwmma.hpp create mode 100644 turbodiffusion/ops/gemm/launch_hip.hpp create mode 100644 turbodiffusion/ops/gemm/launch_rocwmma.hpp create mode 100644 turbodiffusion/ops/gemm/utils_hip.hpp create mode 100644 turbodiffusion/ops/norm/layernorm.hip create mode 100644 turbodiffusion/ops/norm/layernorm_hip.hpp create mode 100644 turbodiffusion/ops/norm/norm_rocm.hpp create mode 100644 turbodiffusion/ops/norm/rmsnorm.hip create mode 100644 turbodiffusion/ops/norm/rmsnorm_hip.hpp create mode 100644 turbodiffusion/ops/quant/quant.hip create mode 100644 turbodiffusion/ops/quant/quant_hip.hpp create mode 100644 turbodiffusion/ops/quant/quant_rocm.hpp diff --git a/.gitignore b/.gitignore index 3645ff0..9938bae 100755 --- a/.gitignore +++ b/.gitignore @@ -1,16 +1,6 @@ -.DS_Store -.vscode -__pycache__ - -# build results -build/ -dist/ -*.egg-info -*.so - -# checkpoints +venv/ +__pycache__/ +.vscode/ checkpoints/ - -# outputs -output*/ -*.mp4 
\ No newline at end of file +output/ +*.egg-info/ diff --git a/build_ext_log.txt b/build_ext_log.txt new file mode 100644 index 0000000000000000000000000000000000000000..7fd75b3aef08f4ee5dccd94658093f339bc49400 GIT binary patch literal 235342 zcmeI5X>%LNv99~GBhD|lUvTYXOR|TgC}}8JbL0q(mJ!3+uGTpr%+M4Ui6d?%DUq~K z%um1fzE34e)eQm!z;3X+IT;85y?5nOS-Dl#|NFo9v)frW`z2e>w)Foh`<(4%@3Wok zezuXV>+?C=*Vaz9t+NMT&#n$QyOZte+OF>2(A8YAoqf=e*=$8;cC$`4ll@J<=lRBW zdjBW2g!iwhP2cGK+k&Q}HhmP#v~W$={FMDi_B&m3EEqoqXgk77C;MO7yX>#oHR0x^ zU|-RPtM>z5HnUA#-_c!rI?C_P*Ymuqqvvx?M|UslxxCSE($W1JI>WPO!#l3(8d`8y z*LAXcfhMnYc1=$Or;$E)^`2c}kmoJy_~YTxJFtE>K7SAW(&RNvM=y<}g;rVu-K=-4$`0_{F)r1Uk~PF z!zj{ce;_9tjjTPqDUC8<=up(OQQY)~O|{ar2D_HG?uSdc<-3^c>>LXYR3kGk?e zQ0+?=%;}StJ;q}uFQPs&8=n3_SJJ=d)F#6bygDa7Wn`ou?`!Wca4(XBmLk{sWzvl9 zGO09|{XxZ6Vyk@%M98Q+6DgQ<;j03MC9jP}!`8AZU>uKEYMDtLr1=?|O2=33Ze zT%zwRg=_XS0`jbOjW-MGb5)EtjF)*XSqVI7^3L=QbX9sLm>t{~jfWyFz-1fxslUy| zamB_9MgjT{QX1gd%0s-JZds)M|(DYrIDz=FAo}- zZV#2({WQKEIOmy0x<$RotPlDG($;h}IE&-x-Uic4onKdq3%m(Y%dM?)_BU8&w%f2bg_OF3^;VZKvlMa7h z$HvjIi=4~IF}QDnOP70SMlSbU?wzk3oAhmixg-*E+_TC)y`Qv#9^^V?q}7 zI1C)L+|v>9suukE46Ax+(Qey!uzgpULNn&wUbEDqciZ}WGz?{{6y-u#I5`4IiK`T|KTY_sTjw~ z@L_7rpeLOzw)Z@=`GbBLzvxZRb0b%AFOFjMzP%`$tlt^zMOeD>498e#S27466ZFCJ zjTfPtz0ol|tmstT5RXDape=NJ{$SN*WVSu?gP}~(oj)YaDeU$WXM1`|r`m&yOOZwXtJ>hMj z)~e`x82ZHDwSN${=0(l^H*__sTHIn*d_z>?nvM|5=&Fry9kF-DztMa;Nh>C42QArZ zKr7HN%%6e9P&ZD9!E85#K9+gX2@)R<2%`@fC1%Z9e1(2=(gx&Nmqh)OSjMOFPG{csy$iob@m-ej)^A-7 zZg>B^28rcQFg~p6ZN+>Eaw!<|Q~rvz+3e3^E8}yJN%>44<9Ga2O=&$jWY>e=D&hF- zt$!+L@lS!7yuN0#6F;Eo#hjl+ZYKMK=*6fEr@U>LZ^fFlBD`ZUn{!^hm@I|jky+^p z@c$&RGzFTj_%`J?%!&KpSK_W`gQ_WT&4pe^2Gd#LXbL1JDP;12IMdl-YYKF59FQPA zv0Km&vF~>^f)#tBWfew0^^*Er8R*ZNUtF|*&*Fn9`YFo6P$v{*J6Jw9Q3CTDwHcb@ z9w##5d??9VWDyLqpF24}4>^g$toTk=-ipWK-l+hNCP0FWAR1==2_!=`kTnsu#IgjRwbiL@gEb1Y4~YV1tUjv5tJvS=q!&rS6M=Gw(0$p$B$K`f)F{n2lh zgRQb&!ZzvW65DShuT+nXQCOJ1zET_|mShQ3GtO5F*-6eE9#LWLREiLo*uM^g5G4-Je5cL3YzG}t`Z0J^EsM06^F{};Pt~o z4~-dW<+-zdc=}}!8H7~hVkT92_(u1>^Y=y?gk)b+Wvdz)GaAk&N|}>UR2y|KGJo3` zTNQ-u9AWeJT=_oUP8^){8%@I_a#3(otvBSNV5MqL$VI`+H0=G0f}5)C*!nd{%XCM_ zFYhU>ADYys%a4rpL({L_10Av4YH|&l2_G}IsbYLF=|{xIvUtn)w$WD0`CDg)qMcP$ zpoUNt)>#v|G|t4l2t2y4{zk1>tO{l`l!vJad>NsN(xf~m8Xwk`$rEdYm8VWSH8Jw? 
zd!lXHNxO+%AL#rW9oyI2$*<_B=0B@<5{ur^_19GG5?<=++2CPI=fE8i@2>8=r{7mP zv!i?M`7e7%=hfQTU=_})jZ~Pps&7|tfyHk2NcWI$@uQ9|X#c49Z&?1Mdzmx3Ca5~u zobKQaK(!P07qh2&_Y1|@pXmEkZMd5~7W7ZEXW{%~9l4kN5Vr5@%;T{2L{~jk<<5t? z=cTT=9q#{LzxP!sVo|?$1S{|Bzw)*21iR_^i`nbVmG8y*87oYS`E(1_(HndHa zgHb)c4FAZClN7MN_gXkYl0dmQD&G}Vcf}z{hGiXpEQ*rl0rrZ0mAjtjdUtWpdU}Ma z-c}64z}fo1bM>pd>`St)qrN{3&qcg7O~xS=9!rKi6l{@RNH`?;@5McBh0FNAEF_)m za!=Y{5q92&@nBVb;!wS&ll?u&LVCiyMvR0#_GZJE@n}Kc`LN&Bh{T9?pli!vc(Qhl z(=yt&Rk~eG7`2CSG-LP0;OH`tOvX8}7mdf!=hv>uI6Ag#OTG3>uLsV4)llpRx+1x{ zTN)G7x!7ORk(JQb(|*!<1|3_{+jiBn2A|7o$LR_iRZK5H3a2yn?j`A0a{!i9>Q*S*@rpyM(Xm-L( z&wH?>7Dn&Etfa=VWD?r@qV&_}Jv(wT`!x^CdCH=yaA5cc2lpk%AIgS*Ej*yHV^iN2 zwroV(6>lEoatZEfANk3MX8s*}5`Gp|OKgWstQw2(KOALplsOLgl-e*lRD}tvZH5O9 z4tgpJrXwyg3pskB(k!NRagWV0q91`PWTE*P@vHRvBjZe`I7?|GaN%n9ov!2&`=2nngP!2A}|H-wq1bq$OkI%L9pL7L2Pv*Jo^*o^V ej1?ULdOW4bm)JY9B_y)th`>s@q0OaDJPXTM+o literal 0 HcmV?d00001 diff --git a/pyproject.toml b/pyproject.toml index 4036059..38b34f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ dependencies = [ "torch>=2.7.0", "torchvision", - "triton>=3.3.0", + "triton-windows>=3.3.0", "flash-attn", "einops", "numpy", diff --git a/rocwmma_lib b/rocwmma_lib new file mode 160000 index 0000000..c360d54 --- /dev/null +++ b/rocwmma_lib @@ -0,0 +1 @@ +Subproject commit c360d5484a5f2c8dacb166154dbe462c4777db5a diff --git a/setup.py b/setup.py index 2ca06f9..5d1bcbe 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ """ Copyright (c) 2025 by TurboDiffusion team. -Licensed under the Apache License, Version 2.0 (the "License"); +Licensed under the Apache License, Version 2.0 (the "License") Citation (please cite if you use this code): @@ -16,60 +16,117 @@ from pathlib import Path from setuptools import setup, find_packages from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import os +import sys + +import torch + +is_rocm = torch.version.hip is not None + +# On Windows, deduplicate INCLUDE/LIB/LIBPATH to avoid "command line too long" errors +if sys.platform == 'win32': + for var in ['INCLUDE', 'LIB', 'LIBPATH']: + val = os.environ.get(var, '') + if val: + unique = [] + seen = set() + for p in val.split(';'): + if p.lower() not in seen and p: + seen.add(p.lower()) + unique.append(p) + os.environ[var] = ';'.join(unique) ops_dir = Path(__file__).parent / "turbodiffusion" / "ops" cutlass_dir = ops_dir / "cutlass" +rocwmma_dir = Path(__file__).parent / "rocwmma_lib" / "projects" / "rocwmma" / "library" / "include" -nvcc_flags = [ - "-O3", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "--ptxas-options=--verbose,--warn-on-local-memory-usage", - "-lineinfo", - "-DCUTLASS_DEBUG_TRACE_LEVEL=0", - "-DNDEBUG", - "-Xcompiler", - "-fPIC" -] +if is_rocm: + # HIP/ROCm build with rocWMMA + hip_flags = [ + "-O3", + "-std=c++17", + "-D__HIP_PLATFORM_AMD__", + "-DNDEBUG", + # Undefine PyTorch's half conversion restrictions - rocWMMA needs these + "-U__HIP_NO_HALF_OPERATORS__", + "-U__HIP_NO_HALF_CONVERSIONS__", + ] + + # Windows-specific: add C/C++ runtime libraries for clang-cl + extra_libraries = [] + extra_link_args = [] + if sys.platform == 'win32': + extra_libraries = ["msvcrt", "vcruntime", "ucrt"] + # Force linking with MSVC C++ runtime + extra_link_args = 
["/DEFAULTLIB:msvcprt"] + + ext_modules = [ + CUDAExtension( + name="turbo_diffusion_ops", + sources=[ + "turbodiffusion/ops/bindings.cpp", + "turbodiffusion/ops/quant/quant.hip", + "turbodiffusion/ops/norm/rmsnorm.hip", + "turbodiffusion/ops/norm/layernorm.hip", + "turbodiffusion/ops/gemm/gemm_rocwmma.hip", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17", "-D__HIP_PLATFORM_AMD__"], + "nvcc": hip_flags, + }, + include_dirs=[ + str(rocwmma_dir), + str(ops_dir), + ], + libraries=extra_libraries, + extra_link_args=extra_link_args, + ) + ] +else: + # CUDA build with CUTLASS + nvcc_flags = [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "-lineinfo", + "-DNDEBUG", + ] -cc_flag = [ - "-gencode", "arch=compute_120a,code=sm_120a", - "-gencode", "arch=compute_100,code=sm_100", - "-gencode", "arch=compute_90,code=sm_90", - "-gencode", "arch=compute_89,code=sm_89", - "-gencode", "arch=compute_80,code=sm_80" -] + cc_flag = [ + "-gencode", "arch=compute_120a,code=sm_120a", + "-gencode", "arch=compute_100,code=sm_100", + "-gencode", "arch=compute_90,code=sm_90", + "-gencode", "arch=compute_89,code=sm_89", + "-gencode", "arch=compute_80,code=sm_80" + ] -ext_modules = [ - CUDAExtension( - name="turbo_diffusion_ops", - sources=[ - "turbodiffusion/ops/bindings.cpp", - "turbodiffusion/ops/quant/quant.cu", - "turbodiffusion/ops/norm/rmsnorm.cu", - "turbodiffusion/ops/norm/layernorm.cu", - "turbodiffusion/ops/gemm/gemm.cu" - ], - extra_compile_args={ - "cxx": ["-O3", "-std=c++17"], - "nvcc": nvcc_flags + ["-DEXECMODE=0"] + cc_flag + ["--threads", "4"], - }, - include_dirs=[ - cutlass_dir / "include", - cutlass_dir / "tools" / "util" / "include", - ops_dir - ], - libraries=["cuda"], - ) -] + ext_modules = [ + CUDAExtension( + name="turbo_diffusion_ops", + sources=[ + "turbodiffusion/ops/bindings.cpp", + "turbodiffusion/ops/quant/quant.cu", + "turbodiffusion/ops/norm/rmsnorm.cu", + "turbodiffusion/ops/norm/layernorm.cu", + "turbodiffusion/ops/gemm/gemm.cu" + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"], + "nvcc": nvcc_flags + ["-DEXECMODE=0"] + cc_flag, + }, + include_dirs=[ + str(cutlass_dir / "include"), + str(cutlass_dir / "tools" / "util" / "include"), + str(ops_dir), + ], + libraries=["cuda"], + ) + ] setup( packages=find_packages( diff --git a/turbodiffusion/SLA/core.py b/turbodiffusion/SLA/core.py index 430bfe0..3c56b88 100755 --- a/turbodiffusion/SLA/core.py +++ b/turbodiffusion/SLA/core.py @@ -17,6 +17,9 @@ import torch.nn as nn import torch.nn.functional as F +# Check for ROCm +IS_ROCM = torch.version.hip is not None + SAGESLA_ENABLED = True try: import spas_sage_attn._qattn as qattn @@ -182,10 +185,19 @@ def forward(self, q, k, v, return_sparsity=False): k = k.transpose(1, 2).contiguous() v = v.transpose(1, 2).contiguous() - arch = get_cuda_arch(q.device.index) + arch = get_cuda_arch(q.device.index) if not IS_ROCM else "rocm" + headdim = q.size(-1) if arch == "sm90": sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=self.topk, BLKQ=64, BLKK=128) + elif IS_ROCM: + # ROCm: use smaller tiles for head_dim=128 to reduce register pressure + # head_dim=64: CTA_Q=64, CTA_K=64 + # head_dim=128: CTA_Q=32, CTA_K=16 (best performance at 10% sparsity) + blkq = 32 if headdim == 128 else 64 + blkk = 16 if headdim == 128 else 64 + sparse_map, lut, real_topk = get_block_map(q, k, 
topk_ratio=self.topk, BLKQ=blkq, BLKK=blkk) else: + # Use 128x64 blocks for sm80-like archs sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=self.topk, BLKQ=128, BLKK=64) q = q.to(self.dtype) @@ -195,26 +207,46 @@ def forward(self, q, k, v, return_sparsity=False): ########## SPARGE BEGIN ########## km = k.mean(dim=-2, keepdim=True) - headdim = q.size(-1) if arch == "sm90": q_int8, q_scale, k_int8, k_scale = get_vanilla_qk_quant(q, k, km, 64, 128) + elif IS_ROCM: + # ROCm: use smaller tiles for head_dim=128 to reduce register pressure + # head_dim=64: CTA_Q=64, CTA_K=64 + # head_dim=128: CTA_Q=32, CTA_K=16 (best performance at 10% sparsity) + blkq = 32 if headdim == 128 else 64 + blkk = 16 if headdim == 128 else 64 + q_int8, q_scale, k_int8, k_scale = get_vanilla_qk_quant(q, k, km, blkq, blkk) else: + # Use 128x64 block sizes for sm80-like archs q_int8, q_scale, k_int8, k_scale = get_vanilla_qk_quant(q, k, km, 128, 64) lut, valid_block_num = block_map_lut_triton(sparse_map) scale = 1.0 / (headdim ** 0.5) assert headdim in [64, 128], "headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale." - o_s = torch.empty_like(q) - - if arch in ("sm80", "sm86", "sm87"): + if IS_ROCM: + # ROCm: kernel natively supports both float16 and bfloat16 + o_s = torch.empty_like(q) + pvthreshold = torch.full((q.shape[-3],), 1e6, dtype=torch.float32, device=q.device) + # Pass V in its native dtype (fp16 or bf16) - kernel handles both + qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold( + q_int8, k_int8, v, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, 1, False, 1, scale, 0 + ) + elif arch in ("sm80", "sm86", "sm87"): + # NVIDIA sm80-sm87: requires FP16 V, kernel outputs float16 + o_s = torch.empty(q.shape, dtype=torch.float16, device=q.device) pvthreshold = torch.full((q.shape[-3],), 1e6, dtype=torch.float32, device=q.device) v_fp16 = v.to(torch.float16) + qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold( q_int8, k_int8, v_fp16, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, 1, False, 1, scale, 0 ) + # Convert back to original dtype (may be bfloat16) + o_s = o_s.to(self.dtype) else: + # NVIDIA sm89+: use FP8 V kernels + o_s = torch.empty_like(q) b, h_kv, kv_len, head_dim = v.shape padded_len = (kv_len + 127) // 128 * 128 v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device) diff --git a/turbodiffusion/imaginaire/utils/distributed.py b/turbodiffusion/imaginaire/utils/distributed.py index e42ca42..b5d24e6 100644 --- a/turbodiffusion/imaginaire/utils/distributed.py +++ b/turbodiffusion/imaginaire/utils/distributed.py @@ -28,11 +28,11 @@ import pynvml import torch import torch.distributed as dist -from torch.distributed import get_process_group_ranks from imaginaire.utils.device import Device if dist.is_available(): + from torch.distributed import get_process_group_ranks from torch.distributed.distributed_c10d import _get_default_group from torch.distributed.utils import _sync_module_states, _verify_param_shape_across_processes diff --git a/turbodiffusion/imaginaire/utils/misc.py b/turbodiffusion/imaginaire/utils/misc.py index 06d4236..634b403 100644 --- a/turbodiffusion/imaginaire/utils/misc.py +++ b/turbodiffusion/imaginaire/utils/misc.py @@ -30,8 +30,10 @@ import numpy as np import termcolor import torch -from torch.distributed._functional_collectives import AsyncCollectiveTensor -from torch.distributed._tensor.api import DTensor + +if 
torch.distributed.is_available(): + from torch.distributed._functional_collectives import AsyncCollectiveTensor + from torch.distributed._tensor.api import DTensor from imaginaire.utils import distributed, log diff --git a/turbodiffusion/ops/bindings.cpp b/turbodiffusion/ops/bindings.cpp index a87adf5..221f1c1 100644 --- a/turbodiffusion/ops/bindings.cpp +++ b/turbodiffusion/ops/bindings.cpp @@ -1,3 +1,12 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Python bindings for TurboDiffusion GPU operations. + * Supports both CUDA (NVIDIA) and HIP (AMD ROCm) backends. + */ + #include namespace py = pybind11; @@ -12,4 +21,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { register_rms_norm(m); register_layer_norm(m); register_gemm(m); -} \ No newline at end of file +} diff --git a/turbodiffusion/ops/common/common_hip.hpp b/turbodiffusion/ops/common/common_hip.hpp new file mode 100644 index 0000000..df11630 --- /dev/null +++ b/turbodiffusion/ops/common/common_hip.hpp @@ -0,0 +1,108 @@ +// !!! This is a file automatically generated by hipify!!! +#pragma once + +#include +#include +#include + +// Define CUTLASS macros for HIP compatibility +#ifndef CUTLASS_HOST_DEVICE +#define CUTLASS_HOST_DEVICE __host__ __device__ +#endif +#ifndef CUTLASS_DEVICE +#define CUTLASS_DEVICE __device__ +#endif +#ifndef CUTLASS_HOST +#define CUTLASS_HOST __host__ +#endif + +// Define __grid_constant__ if not available (CUDA 11.5+ feature) +#ifndef __grid_constant__ +#define __grid_constant__ +#endif + +// Define CUTLASS pragma macros for HIP +#ifndef CUTLASS_PRAGMA_UNROLL +#define CUTLASS_PRAGMA_UNROLL _Pragma("unroll") +#endif +#ifndef CUTLASS_PRAGMA_NO_UNROLL +#define CUTLASS_PRAGMA_NO_UNROLL _Pragma("nounroll") +#endif + +inline CUTLASS_HOST_DEVICE int64_t cdiv(int64_t const& a, int64_t const &b) { + return (a + b - 1) / b; +} + +// Note: Don't define max/min as they conflict with HIP builtins +// Use std::max/std::min or the built-in max/min instead + +#define MIN(a, b) ((a) > (b) ? (b) : (a)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ +[&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return (__VA_ARGS__)(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return (__VA_ARGS__)(); \ + } \ +}() + +#define CUDA_CHECK(call) \ +{ \ + hipError_t err = call; \ + if (err != hipSuccess) { \ + fprintf(stderr, "CUDA Error at %s:%d: %s\n", __FILE__, __LINE__, hipGetErrorString(err)); \ + exit(err); \ + } \ +} + +#define CONFIG_SWITCH(N, ...) 
\ +[&] { \ + if (N <= 1024) { \ + constexpr int NUM_THR_PER_CTA = 128; \ + constexpr int MAX_HIDDEN_SIZE = 1024; \ + return (__VA_ARGS__)(); \ + } else if (N <= 2048) { \ + constexpr int NUM_THR_PER_CTA = 128; \ + constexpr int MAX_HIDDEN_SIZE = 2048; \ + return (__VA_ARGS__)(); \ + } else if (N <= 4096) { \ + constexpr int NUM_THR_PER_CTA = 128; \ + constexpr int MAX_HIDDEN_SIZE = 4096; \ + return (__VA_ARGS__)(); \ + } else if (N <= 8192) { \ + constexpr int NUM_THR_PER_CTA = 256; \ + constexpr int MAX_HIDDEN_SIZE = 8192; \ + return (__VA_ARGS__)(); \ + } else { \ + constexpr int NUM_THR_PER_CTA = 256; \ + constexpr int MAX_HIDDEN_SIZE = 16384; \ + return (__VA_ARGS__)(); \ + } \ +}() + + +template + void create_tensor( + torch::Device const &device, + std::optional &output, + std::optional &scale, + int m, int n + ) { + int num_block_m = cdiv(m, BlockSize); + int num_block_n = cdiv(n, BlockSize); + if (!output.has_value()) { + output.emplace(torch::empty( + {m, n}, + torch::TensorOptions().device(device).dtype(torch::kInt8) + )); + scale.emplace(torch::empty( + {num_block_m, num_block_n}, + torch::TensorOptions().device(device).dtype(torch::kFloat32) + )); + } + } + diff --git a/turbodiffusion/ops/common/launch_hip.hpp b/turbodiffusion/ops/common/launch_hip.hpp new file mode 100644 index 0000000..a2cc17d --- /dev/null +++ b/turbodiffusion/ops/common/launch_hip.hpp @@ -0,0 +1,45 @@ +// !!! This is a file automatically generated by hipify!!! +#pragma once + +#include +#include + + +template +__global__ void device_kernel( + __grid_constant__ typename Kernel::Params const params +) { + extern __shared__ char smem[]; + Kernel op; + op(params, smem); +} + +template +__global__ __launch_bounds__(Kernel::MaxThreadsPerBlock, Kernel::MinBlocksPerMultiprocessor) +void device_kernel_with_launch_bounds( + __grid_constant__ typename Kernel::Params const params +) { + extern __shared__ char smem[]; + Kernel op; + op(params, smem); +} + +template +void launch_kernel( + typename Kernel::Params const ¶ms, + dim3 grid_shape, + dim3 cta_shape, + size_t ShmSize, + hipStream_t stream = nullptr +) { + auto func = device_kernel; + if (ShmSize >= 48 * 1024) { + CUDA_CHECK(hipFuncSetAttribute( + func, + hipFuncAttributeMaxDynamicSharedMemorySize, + ShmSize + )); + } + hipLaunchKernelGGL(( func), dim3(grid_shape), dim3(cta_shape), ShmSize, stream, params); + CUDA_CHECK(hipGetLastError()); +} diff --git a/turbodiffusion/ops/common/load.hpp b/turbodiffusion/ops/common/load.hpp index f72bcad..ae8ca7d 100644 --- a/turbodiffusion/ops/common/load.hpp +++ b/turbodiffusion/ops/common/load.hpp @@ -1,5 +1,48 @@ #pragma once +// Include common_hip.hpp for CUTLASS macro definitions when building for HIP +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) +#include "common/common_hip.hpp" +#include +#include + +// Helper functions for type conversion on HIP +namespace turbo_hip { + +template +__device__ __forceinline__ T from_float(float val) { + return static_cast(val); +} + +template <> +__device__ __forceinline__ __half from_float<__half>(float val) { + return __float2half(val); +} + +template <> +__device__ __forceinline__ hip_bfloat16 from_float(float val) { + return hip_bfloat16(val); +} + +template +__device__ __forceinline__ float to_float(T val) { + return static_cast(val); +} + +template <> +__device__ __forceinline__ float to_float<__half>(__half val) { + return __half2float(val); +} + +template <> +__device__ __forceinline__ float to_float(hip_bfloat16 val) { + return static_cast(val); +} + +} // namespace 
turbo_hip + +#endif + template < class InputDtype_, int TileM_, @@ -30,8 +73,13 @@ class Loader { void const *thr_input_ptr = (void*)((InputDtype*)cta_input_ptr + thr_m_offset * n + thr_n_offset); InputDtype tmp_reg[NumElementPerThread]; CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < NumElementPerThread; ++i) + for (int i = 0; i < NumElementPerThread; ++i) { +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + tmp_reg[i] = turbo_hip::from_float(0.f); +#else tmp_reg[i] = InputDtype(0.f); +#endif + } bool pred = IsEvenM ? true : thr_m_offset + blk_m * TileM < m; int limit = IsEvenN ? NumElementPerThread : MIN(NumElementPerThread, n - (blk_n * TileN + thr_n_offset)); if (n_alignment % 128 == 0) @@ -42,8 +90,13 @@ class Loader { _load(thr_input_ptr, (void*)tmp_reg, limit, pred); CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < NumElementPerThread; ++i) + for (int i = 0; i < NumElementPerThread; ++i) { +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + *((float*)thr_output_reg + i) = turbo_hip::to_float(tmp_reg[i]); +#else *((float*)thr_output_reg + i) = static_cast(reinterpret_cast(tmp_reg[i])); +#endif + } } private: diff --git a/turbodiffusion/ops/common/platform.hpp b/turbodiffusion/ops/common/platform.hpp new file mode 100644 index 0000000..284af0a --- /dev/null +++ b/turbodiffusion/ops/common/platform.hpp @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Platform abstraction layer for CUDA/HIP compatibility. + * This header provides unified macros and types for both backends. + */ + +#pragma once + +// Detect platform +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + #define TURBO_PLATFORM_HIP 1 + #define TURBO_PLATFORM_CUDA 0 +#else + #define TURBO_PLATFORM_HIP 0 + #define TURBO_PLATFORM_CUDA 1 +#endif + +// Include appropriate runtime headers +#if TURBO_PLATFORM_HIP + #include + #include +#else + #include + #include + #include +#endif + +// Stream type abstraction +#if TURBO_PLATFORM_HIP + using turboStream_t = hipStream_t; + using turboError_t = hipError_t; + #define turboSuccess hipSuccess + #define turboGetErrorString hipGetErrorString + #define turboGetLastError hipGetLastError + #define turboFuncSetAttribute hipFuncSetAttribute + #define turboFuncAttributeMaxDynamicSharedMemorySize hipFuncAttributeMaxDynamicSharedMemorySize +#else + using turboStream_t = cudaStream_t; + using turboError_t = cudaError_t; + #define turboSuccess cudaSuccess + #define turboGetErrorString cudaGetErrorString + #define turboGetLastError cudaGetLastError + #define turboFuncSetAttribute cudaFuncSetAttribute + #define turboFuncAttributeMaxDynamicSharedMemorySize cudaFuncAttributeMaxDynamicSharedMemorySize +#endif + +// Device function qualifiers +#if TURBO_PLATFORM_HIP + #define TURBO_HOST __host__ + #define TURBO_DEVICE __device__ + #define TURBO_HOST_DEVICE __host__ __device__ + #define TURBO_KERNEL __global__ + #define TURBO_INLINE __forceinline__ +#else + #define TURBO_HOST __host__ + #define TURBO_DEVICE __device__ + #define TURBO_HOST_DEVICE __host__ __device__ + #define TURBO_KERNEL __global__ + #define TURBO_INLINE __forceinline__ +#endif + +// Pragma unroll +#if TURBO_PLATFORM_HIP + #define TURBO_PRAGMA_UNROLL _Pragma("unroll") + #define TURBO_PRAGMA_NO_UNROLL _Pragma("nounroll") +#else + #define TURBO_PRAGMA_UNROLL _Pragma("unroll") + #define TURBO_PRAGMA_NO_UNROLL _Pragma("nounroll") +#endif + +// Error checking macro +#define TURBO_CHECK(call) \ + do { \ + turboError_t err = (call); \ 
+ if (err != turboSuccess) { \ + fprintf(stderr, "GPU Error at %s:%d: %s\n", __FILE__, __LINE__, \ + turboGetErrorString(err)); \ + exit(err); \ + } \ + } while (0) + +// Half precision types +#if TURBO_PLATFORM_HIP + using half_t = __half; + using bfloat16_t = hip_bfloat16; + + TURBO_DEVICE TURBO_INLINE float __int2float_rn_hip(int x) { + return static_cast(x); + } + #define __int2float_rn __int2float_rn_hip + + TURBO_DEVICE TURBO_INLINE float __int_as_float_hip(int x) { + return __int_as_float(x); + } +#else + #include + using half_t = cutlass::half_t; + using bfloat16_t = cutlass::bfloat16_t; +#endif + +// Warp/Wave primitives +#if TURBO_PLATFORM_HIP + // RDNA3 uses wave32 + #define TURBO_WARP_SIZE 32 + #define TURBO_FULL_MASK 0xFFFFFFFFu + + TURBO_DEVICE TURBO_INLINE float warpReduceSum(float val) { + TURBO_PRAGMA_UNROLL + for (int offset = TURBO_WARP_SIZE / 2; offset > 0; offset >>= 1) { + val += __shfl_xor(val, offset, TURBO_WARP_SIZE); + } + return val; + } + + TURBO_DEVICE TURBO_INLINE float warpReduceMax(float val) { + TURBO_PRAGMA_UNROLL + for (int offset = TURBO_WARP_SIZE / 2; offset > 0; offset >>= 1) { + val = fmaxf(val, __shfl_xor(val, offset, TURBO_WARP_SIZE)); + } + return val; + } +#else + #define TURBO_WARP_SIZE 32 + #define TURBO_FULL_MASK 0xFFFFFFFFu + + TURBO_DEVICE TURBO_INLINE float warpReduceSum(float val) { + TURBO_PRAGMA_UNROLL + for (int offset = TURBO_WARP_SIZE / 2; offset > 0; offset >>= 1) { + val += __shfl_xor_sync(TURBO_FULL_MASK, val, offset); + } + return val; + } + + TURBO_DEVICE TURBO_INLINE float warpReduceMax(float val) { + TURBO_PRAGMA_UNROLL + for (int offset = TURBO_WARP_SIZE / 2; offset > 0; offset >>= 1) { + val = fmaxf(val, __shfl_xor_sync(TURBO_FULL_MASK, val, offset)); + } + return val; + } +#endif + +// Synchronization +#if TURBO_PLATFORM_HIP + #define __syncwarp() __syncthreads() +#endif + +// Kernel launch helper +template +TURBO_KERNEL void device_kernel_impl( + typename Kernel::Params const params +) { + extern __shared__ char smem[]; + Kernel op; + op(params, smem); +} + +template +void launch_kernel_unified( + typename Kernel::Params const& params, + dim3 grid_shape, + dim3 cta_shape, + size_t ShmSize, + turboStream_t stream = nullptr +) { + auto func = device_kernel_impl; + if (ShmSize >= 48 * 1024) { + TURBO_CHECK(turboFuncSetAttribute( + func, + turboFuncAttributeMaxDynamicSharedMemorySize, + ShmSize + )); + } +#if TURBO_PLATFORM_HIP + hipLaunchKernelGGL(func, dim3(grid_shape), dim3(cta_shape), ShmSize, stream, params); +#else + func<<>>(params); +#endif + TURBO_CHECK(turboGetLastError()); +} + +// Numeric conversion helpers +namespace turbo { + +template +TURBO_DEVICE TURBO_INLINE To convert(From val) { + return static_cast(val); +} + +#if TURBO_PLATFORM_HIP +template <> +TURBO_DEVICE TURBO_INLINE int8_t convert(float val) { + // Round to nearest with clamping + val = fmaxf(-128.0f, fminf(127.0f, rintf(val))); + return static_cast(val); +} +#else +template <> +TURBO_DEVICE TURBO_INLINE int8_t convert(float val) { + return cutlass::NumericConverter()(val); +} +#endif + +} // namespace turbo + diff --git a/turbodiffusion/ops/common/store_hip.hpp b/turbodiffusion/ops/common/store_hip.hpp new file mode 100644 index 0000000..3c1b4e1 --- /dev/null +++ b/turbodiffusion/ops/common/store_hip.hpp @@ -0,0 +1,78 @@ +// !!! This is a file automatically generated by hipify!!! 
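The platform layer above is only consumed indirectly by the kernels later in this patch, so here is a minimal usage sketch. It is illustrative only: the names row_sum_sketch and launch_row_sum are not part of the patch, and the snippet assumes nothing beyond the macros and helpers defined in common/platform.hpp.

#include <cstdio>
#include "common/platform.hpp"

// One partial sum per wave/warp; TURBO_KERNEL expands to __global__ on both backends.
// out must hold one float per launched wave/warp.
TURBO_KERNEL void row_sum_sketch(float const* in, float* out, int n) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    float v = (tid < n) ? in[tid] : 0.0f;
    v = warpReduceSum(v);                      // wave32 shuffle on ROCm, __shfl_xor_sync on CUDA
    if (threadIdx.x % TURBO_WARP_SIZE == 0)
        out[tid / TURBO_WARP_SIZE] = v;
}

void launch_row_sum(float const* d_in, float* d_out, int n, turboStream_t stream) {
    int threads = 128;
    int blocks = (n + threads - 1) / threads;
    // hipcc and nvcc both accept the triple-chevron launch syntax.
    row_sum_sketch<<<blocks, threads, 0, stream>>>(d_in, d_out, n);
    TURBO_CHECK(turboGetLastError());
}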
+#pragma once +#include "common/common_hip.hpp" + + +template < + class OutputDtype_, + int TileM_, + int TileN_, + int NumThrPerCta_, + bool IsEvenM, + bool IsEvenN, + bool Round = true, + bool SaveScale = true +> +class Saver { +public: + using OutputDtype = OutputDtype_; + + static constexpr int TileM = TileM_; + static constexpr int TileN = TileN_; + static constexpr int NumThrPerCta = NumThrPerCta_; + static constexpr int NumElementPerThread = TileM * TileN / NumThrPerCta; + static constexpr int NumThrPerRow = TileN / NumElementPerThread; + + static_assert(TileM * TileN % NumThrPerCta == 0); + static_assert(NumThrPerCta % TileM == 0); + + CUTLASS_DEVICE void + store(void *Optr, void *OSptr, void *reg, float scale_inv, int64_t m, int64_t n, int blk_m, int blk_n, int tid) { + int n_alignment = (n & 31) * sizeof(OutputDtype); + int thr_m_offset = tid / NumThrPerRow; + int thr_n_offset = (tid % NumThrPerRow) * NumElementPerThread; + void *cta_output_ptr = (void*)((OutputDtype*)Optr + blk_m * TileM * (Round ? cdiv(n, TileN) * TileN : n) + blk_n * TileN); + void *thr_output_ptr = (void*)((OutputDtype*)cta_output_ptr + thr_m_offset * (Round ? cdiv(n, TileN) * TileN : n) + thr_n_offset); + bool pred = IsEvenM ? true : thr_m_offset + blk_m * TileM < m; + int limit = IsEvenN ? NumElementPerThread : MIN(NumElementPerThread, n - (blk_n * TileN + thr_n_offset)); + if (n_alignment % 128 == 0) + _store(thr_output_ptr, reg, limit, pred); + else if (n_alignment % 64 == 0) + _store(thr_output_ptr, reg, limit, pred); + else + _store(thr_output_ptr, reg, limit, pred); + + if constexpr (SaveScale) { + if (tid == 0) { + *((float*)OSptr + blk_m * cdiv(n, TileN)+ blk_n) = scale_inv; + } + } + } + +private: + template + CUTLASS_DEVICE void + _store(void *thr_output_ptr, void *reg, int limit, bool pred) { + static constexpr int NumElementPerStore = sizeof(StoreDataType) / sizeof(OutputDtype); + if (pred) { + if constexpr (IsEven) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; i += NumElementPerStore) { + *(StoreDataType*)((OutputDtype*)thr_output_ptr + i) = *(StoreDataType*)((OutputDtype*)reg + i); + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < limit; i += NumElementPerStore) { + if (limit - i > NumElementPerStore) + *(StoreDataType*)((OutputDtype*)thr_output_ptr + i) = *(StoreDataType*)((OutputDtype*)reg + i); + else { + for (int j = 0; j < limit - i; ++j) { + *((OutputDtype*)thr_output_ptr + i + j) = *((OutputDtype*)reg + i + j); + } + } + } + } + } + } + +}; diff --git a/turbodiffusion/ops/gemm/gemm_rocwmma.hip b/turbodiffusion/ops/gemm/gemm_rocwmma.hip new file mode 100644 index 0000000..53ca6d7 --- /dev/null +++ b/turbodiffusion/ops/gemm/gemm_rocwmma.hip @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Citation (please cite if you use this code): + * + * @article{zhang2025turbodiffusion, + * title={TurboDiffusion: Accelerating Video Diffusion Models by 100-200 Times}, + * author={Zhang, Jintao and Zheng, Kaiwen and Jiang, Kai and Wang, Haoxu and Stoica, Ion and Gonzalez, Joseph E and Chen, Jianfei and Zhu, Jun}, + * journal={arXiv preprint arXiv:2512.16093}, + * year={2025} + * } + * + * rocWMMA-based GEMM for AMD ROCm GPUs. 
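+ * Expected layouts, as read off the launcher below (descriptive, not normative):
+ * A is int8 [m, k] row-major, B is int8 [n, k] row-major, A_S and B_S carry one
+ * float scale per 128x128 quantization block, and C is a preallocated [m, n]
+ * output whose dtype (fp16 or bf16) selects the kernel instantiation.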
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/common_hip.hpp" +#include "gemm/launch_rocwmma.hpp" + +void int8_gemm( + at::Tensor const& A, at::Tensor const& A_S, + at::Tensor const& B, at::Tensor const& B_S, + torch::Tensor& C +) { + static constexpr int swizzle_dir = 1; + static constexpr int swizzle_size_log = 5; + + int k = B.size(1); + int m = A.size(0); + int n = B.size(0); + + switch (C.scalar_type()) { + case torch::kHalf: { + int8_gemm_rocwmma<__half>( + (int8_t*)A.data_ptr(), A_S.data_ptr(), + (int8_t*)B.data_ptr(), B_S.data_ptr(), + (__half*)C.data_ptr(), + m, n, k, swizzle_dir, swizzle_size_log, + at::hip::getCurrentHIPStream().stream() + ); + break; + } + + case torch::kBFloat16: { + int8_gemm_rocwmma( + (int8_t*)A.data_ptr(), A_S.data_ptr(), + (int8_t*)B.data_ptr(), B_S.data_ptr(), + (hip_bfloat16*)C.data_ptr(), + m, n, k, swizzle_dir, swizzle_size_log, + at::hip::getCurrentHIPStream().stream() + ); + break; + } + + default: { + std::cerr << "Observing: " << C.scalar_type() << " for the output datatype which is invalid"; + throw std::runtime_error("Unsupported output data type for int8 gemm."); + } + } +} + +void register_gemm(pybind11::module_ &m) { + m.def("gemm_cuda", &int8_gemm); +} + diff --git a/turbodiffusion/ops/gemm/kernel_hip.hpp b/turbodiffusion/ops/gemm/kernel_hip.hpp new file mode 100644 index 0000000..9e2ec80 --- /dev/null +++ b/turbodiffusion/ops/gemm/kernel_hip.hpp @@ -0,0 +1,523 @@ +// !!! This is a file automatically generated by hipify!!! +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Citation (please cite if you use this code): + * + * @article{zhang2025turbodiffusion, + * title={TurboDiffusion: Accelerating Video Diffusion Models by 100-200 Times}, + * author={Zhang, Jintao and Zheng, Kaiwen and Jiang, Kai and Wang, Haoxu and Stoica, Ion and Gonzalez, Joseph E and Chen, Jianfei and Zhu, Jun}, + * journal={arXiv preprint arXiv:2512.16093}, + * year={2025} + * } + */ + +#pragma once + +#include +#include "cute/tensor_hip.hpp" + +#include "common/common_hip.hpp" +#include "gemm/utils_hip.hpp" + +using namespace cute; + +template < + class OutputDtype_, + bool IsEvenM, + bool IsEvenN +> +struct GemmKernel { + using ElementA = int8_t; + using ElementB = int8_t; + using OutputDtype = OutputDtype_; + using AccumulatorDtype = int32_t; + static constexpr int BlockSize = 128; + static constexpr int TileM = 128; + static constexpr int TileN = 128; + static constexpr int TileK = 128; + static constexpr int Stage = 3; + static constexpr int EpiStage = 2; + + static_assert( + BlockSize % TileM == 0 + && BlockSize % TileN == 0 + && BlockSize % TileK == 0 + ); + + static constexpr int NumTilePerBlock = BlockSize / TileK; + + using SmemLayoutAtom = decltype( + composition( + Swizzle<3, 4, 3>{}, + make_layout( + make_shape(Int<8>{}, Int{}), + make_stride(Int{}, Int<1>{}) + ) + ) + ); + + using SmemLayoutA = decltype( + tile_to_shape( + SmemLayoutAtom{}, + make_shape(Int{}, Int{}, Int{}) + ) + ); + + using SmemLayoutB = decltype( + tile_to_shape( + SmemLayoutAtom{}, + make_shape(Int{}, Int{}, Int{}) + ) + ); + + using MmaOP = cute::SM80_16x8x32_S32S8S8S32_TN; + using TiledMma = decltype( + make_tiled_mma( + MMA_Atom>{}, + make_layout(make_shape( + _4{}, _2{}, _1{} + )), + make_tile(Int<64>{}, Int<32>{}, Int<32>{}) + ) + ); + + using G2SCopyAtomA = Copy_Atom>, ElementA>; + using G2SCopyAtomB = Copy_Atom>, ElementB>; + using 
G2STiledCopyA = decltype( + make_tiled_copy( + G2SCopyAtomA{}, + make_layout( + make_shape(Int<64>{}, Int<4>{}), + make_stride(Int<4>{}, Int<1>{}) + ), + make_layout(make_shape(Int<1>{}, Int<16>{})) + ) + ); + using G2STiledCopyB = decltype( + make_tiled_copy( + G2SCopyAtomB{}, + make_layout( + make_shape(Int<64>{}, Int<4>{}), + make_stride(Int<4>{}, Int<1>{}) + ), + make_layout(make_shape(Int<1>{}, Int<16>{})) + ) + ); + + using S2RCopyAtomA = Copy_Atom, ElementA>; + using S2RCopyAtomB = Copy_Atom, ElementB>; + using S2RTiledCopyA = decltype(make_tiled_copy_A(S2RCopyAtomA{}, TiledMma{})); + using S2RTiledCopyB = decltype(make_tiled_copy_B(S2RCopyAtomB{}, TiledMma{})); + + // epilogue + using SmemLayoutAtomD = decltype( + composition( + Swizzle<2, 3, 3>{}, + make_layout( + make_shape(Int<32>{}, Int<32>{}), + LayoutRight{} + ) + ) + ); + + using SmemLayoutD = decltype( + tile_to_shape( + SmemLayoutAtomD{}, + make_shape(Int<64>{}, Int<32>{}, Int{}) + ) + ); + + using R2SCopyAtomD = Copy_Atom>, OutputDtype>; + using R2STiledCopyD = decltype(make_tiled_copy_C(R2SCopyAtomD{}, TiledMma{})); + + using S2GCopyAtomD = Copy_Atom, OutputDtype>; + using S2GCopyD = decltype(make_tiled_copy( + S2GCopyAtomD{}, + make_layout(Shape<_64, _4>{}), + make_layout(Shape<_1, _8>{}) + )); + + using TileShape = decltype(make_shape(Int{}, Int{}, Int{})); + + struct SharedStorageAB: cute::aligned_struct<128> { + array_aligned, 128> smem_A; + array_aligned, 128> smem_B; + array_aligned smem_AS; + array_aligned smem_BS; + array_aligned smem_AF; + }; + + struct SharedStorageD: cute::aligned_struct<128> { + array_aligned> smem_D; + }; + + union SharedStorage { + SharedStorageAB storage_AB; + SharedStorageD storage_D; + }; + + + struct Params { + void const* Aptr; + void const* ASptr; + void const* Bptr; + void const* BSptr; + void* Dptr; + int64_t const m; + int64_t const n; + int64_t const k; + int const swizzle_dir; + int const swizzle_size; + }; + + using Arguments = Params; + + static constexpr int ThreadNum = size(TiledMma{}); + static constexpr int ShmSize = sizeof(SharedStorage); + static constexpr bool FastInt2Float = false; + + static bool can_implement(int64_t m, int64_t n, int64_t k) { + if (k % BlockSize != 0) return false; + if ((n * sizeof(OutputDtype)) % 16 != 0) + return false; + return true; + } + + static Params to_underlying_arguments(Arguments const& args) { + return args; + } + + static dim3 get_grid_size(int64_t m, int64_t n) { + return dim3(cdiv(m, TileM) * cdiv(n, TileN)); + } + + CUTLASS_HOST_DEVICE + static auto get_block_coord( + int64_t m_blocks, + int64_t n_blocks, + int const swizzle_dir, + int64_t const swizzle_size_log + ) { + int64_t blk_m; + int64_t blk_n; + + if (swizzle_dir == 1) + std::swap(m_blocks, n_blocks); + + if (swizzle_size_log == 0) { + blk_m = blockIdx.x % m_blocks; + blk_n = blockIdx.x / m_blocks; + } else { + int64_t group_size = n_blocks << swizzle_size_log; + int64_t num_groups = m_blocks >> swizzle_size_log; + int64_t group_idx = blockIdx.x / group_size; + int64_t local_idx = blockIdx.x % group_size; + if (group_idx == num_groups) { + blk_m = (num_groups << swizzle_size_log) + local_idx % (m_blocks - (num_groups << swizzle_size_log)); + blk_n = local_idx / (m_blocks - (num_groups << swizzle_size_log)); + } else { + blk_m = (local_idx & ((1LL << swizzle_size_log) - 1)) + (group_idx << swizzle_size_log); + blk_n = local_idx >> swizzle_size_log; + } + } + + if (swizzle_dir == 1) + std::swap(blk_m, blk_n); + + return make_coord(blk_m, blk_n); + } + + CUTLASS_DEVICE + void 
operator()( + Params const& params, char* smem_data + ) { + + SharedStorage& shared_storage = *reinterpret_cast(smem_data); + + auto t_idx = threadIdx.x; + + int64_t const m = params.m; + int64_t const n = params.n; + int64_t const k = params.k; + int const swizzle_dir = params.swizzle_dir; + int const swizzle_size = params.swizzle_size; + + Tensor A = make_tensor( + make_gmem_ptr(params.Aptr), + make_shape(m, k), + make_stride(k, _1{}) + ); + Tensor B = make_tensor( + make_gmem_ptr(params.Bptr), + make_shape(m, k), + make_stride(k, _1{}) + ); + Tensor AS = make_tensor( + make_gmem_ptr(params.ASptr), + make_shape(cdiv(m, BlockSize), cdiv(k, BlockSize)), + make_stride(cdiv(k, BlockSize), _1{}) + ); + Tensor BS = make_tensor( + make_gmem_ptr(params.BSptr), + make_shape(cdiv(n, BlockSize), cdiv(k, BlockSize)), + make_stride(cdiv(k, BlockSize), _1{}) + ); + Tensor D = make_tensor( + make_gmem_ptr(params.Dptr), + make_shape(m, n), + LayoutRight{} + ); + + auto [m_coord, n_coord] = get_block_coord( + cdiv(m, size<0>(TileShape{})), + cdiv(n, size<1>(TileShape{})), + swizzle_dir, swizzle_size + ); + + int32_t blk_m_coord = m_coord / (BlockSize / TileM); + int32_t blk_n_coord = n_coord / (BlockSize / TileN); + + // local tile + auto gA = local_tile(A, TileShape{}, make_coord(m_coord, n_coord, _), Step<_1, X, _1>{}); + auto gB = local_tile(B, TileShape{}, make_coord(m_coord, n_coord, _), Step{}); + auto gD = local_tile(D, TileShape{}, make_coord(m_coord, n_coord, _), Step<_1, _1, X>{}); + + // shared memory + Tensor sA = make_tensor( + make_smem_ptr(shared_storage.storage_AB.smem_A.data()), + SmemLayoutA{} + ); + + Tensor sB = make_tensor( + make_smem_ptr(shared_storage.storage_AB.smem_B.data()), + SmemLayoutB{} + ); + + // register + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_slice(t_idx); + auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); + auto tCrB = thr_mma.partition_fragment_B(gB(_, _, 0)); + auto tDrC = thr_mma.partition_fragment_C(gD); // mma accumulator + auto tDrD = make_tensor_like(tDrC); // float accumulator + + if constexpr (FastInt2Float) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tDrC); ++i) + tDrC(i) = 0x4B400000; + } else { + clear(tDrC); + } + + clear(tDrD); + + + // global to shared copy + G2STiledCopyA g2s_tiled_copy_a; + auto g2s_thr_copy_a = g2s_tiled_copy_a.get_slice(t_idx); + auto tAgA = g2s_thr_copy_a.partition_S(gA); + auto tAsA = g2s_thr_copy_a.partition_D(sA); + auto cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); + auto tAcA = g2s_thr_copy_a.partition_S(cA); + int const m_limit = m - TileM * m_coord; + int const n_limit = n - TileN * n_coord; + + G2STiledCopyB g2s_tiled_copy_b; + auto g2s_thr_copy_b = g2s_tiled_copy_b.get_slice(t_idx); + auto tBgB = g2s_thr_copy_b.partition_S(gB); + auto tBsB = g2s_thr_copy_b.partition_D(sB); + auto cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); + auto tBcB = g2s_thr_copy_a.partition_S(cB); + + + // shared to register copy + S2RTiledCopyA s2r_tiled_copy_a; + auto s2r_thr_copy_a = s2r_tiled_copy_a.get_slice(t_idx); + auto tCsA = s2r_thr_copy_a.partition_S(sA); + auto tCrA_view = s2r_thr_copy_a.retile_D(tCrA); + + S2RTiledCopyB s2r_tiled_copy_b; + auto s2r_thr_copy_b = s2r_tiled_copy_b.get_slice(t_idx); + auto tCsB = s2r_thr_copy_b.partition_S(sB); + auto tCrB_view = s2r_thr_copy_b.retile_D(tCrB); + + // pipeline status + int64_t g2s_a_tile = 0; + int64_t g2s_b_tile = 0; + int g2s_a_smem = 0; + int g2s_b_smem = 0; + + int g2s_tile_in_block = 0; + int g2s_block = 0; // b block idx + 
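+    // Software-pipeline bookkeeping. A and B tiles are multi-buffered in shared
+    // memory (Stage = 3): the g2s_* indices above track the next K-tile to prefetch
+    // (global -> shared) and which buffer receives it, while the s2r_* / mma_block_*
+    // indices below track the buffer currently being consumed (shared -> register
+    // -> MMA) and the 128-wide quantization block whose AS/BS scales are applied
+    // when the int32 accumulators are dequantized. With BlockSize == TileK == 128,
+    // NumTilePerBlock is 1, so the scales advance on every K-tile.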
+ int s2r_a_smem = 0; + int s2r_b_smem = 0; + int s2r_tile_in_block = 0; + + int mma_block_a = 0; + int mma_block_b = 0; + + int ntile = k / TileK; + // load scale and fallback + // we assume all ptrs are 128bit aligned + // auto smem_fallback_A = raw_pointer_cast(make_smem_ptr(shared_storage.storage_AB.smem_AF.data())); + // auto smem_scale_A = raw_pointer_cast(make_smem_ptr(shared_storage.storage_AB.smem_AS.data())); + // auto smem_scale_B = raw_pointer_cast(make_smem_ptr(shared_storage.storage_AB.smem_BS.data())); + __syncthreads(); + + + int32_t fallbackA_load = 0; + int32_t fallbackA_mma = 0; + + // copy first Stage - 1 tile + CUTLASS_PRAGMA_UNROLL + for (int i = 0, _i = min(Stage - 1, ntile); i < _i; ++i) { + if (g2s_b_tile < ntile) { + g2s_tile_in_block = (g2s_tile_in_block + 1) % NumTilePerBlock; + copy_AB(g2s_tiled_copy_a, tAgA, tAsA, tAcA, g2s_a_tile, g2s_a_smem, m_limit); + copy_AB(g2s_tiled_copy_b, tBgB, tBsB, tBcB, g2s_b_tile, g2s_b_smem, n_limit); + ++g2s_b_tile; + ++g2s_b_smem; + ++g2s_block; + g2s_a_tile = g2s_block * NumTilePerBlock; + ++g2s_a_smem; + } + cp_async_fence(); + } + + constexpr int nk = size<2>(tCrA); + float scale_a = AS(blk_m_coord, 0); + float scale_b = BS(blk_n_coord, 0); + + CUTLASS_PRAGMA_NO_UNROLL + for (int64_t mma_b_tile = 0; mma_b_tile < ntile; ++mma_b_tile) { + s2r_tile_in_block = (s2r_tile_in_block + 1) % NumTilePerBlock; + cp_async_wait(); + __syncthreads(); + + // do mma first + CUTLASS_PRAGMA_UNROLL + for (int ik = 0; ik < nk; ++ik) { + cute::copy(s2r_tiled_copy_a, tCsA(_, _, ik, s2r_a_smem), + tCrA_view(_, _, ik)); + cute::copy(s2r_tiled_copy_b, tCsB(_, _, ik, s2r_b_smem), + tCrB_view(_, _, ik)); + cute::gemm(tiled_mma, tDrC, tCrA(_, _, ik), tCrB(_, _, ik), tDrC); + } + + // a s2r increase anyway + s2r_a_smem = (s2r_a_smem + 1) % Stage; + + // get next s2r b tile int64_t + // end of a block + + // dequant first + dequant( + tDrC.data(), tDrD.data(), scale_a * scale_b + ); + + s2r_b_smem = (s2r_b_smem + 1) % Stage; + // b advance + ++mma_block_b; + if (mma_block_b < size<1>(BS)) scale_b = BS(blk_n_coord, mma_block_b); + mma_block_a = mma_block_b; + if (mma_block_a < size<1>(AS)) scale_a = AS(blk_m_coord, mma_block_a); + + // load next stage + if (g2s_b_tile < ntile) { + g2s_tile_in_block = (g2s_tile_in_block + 1) % NumTilePerBlock; + copy_AB(g2s_tiled_copy_a, tAgA, tAsA, tAcA, g2s_a_tile, g2s_a_smem, m_limit); + copy_AB(g2s_tiled_copy_b, tBgB, tBsB, tBcB, g2s_b_tile, g2s_b_smem, n_limit); + ++g2s_b_tile; + g2s_b_smem = (g2s_b_smem + 1) % Stage; + ++g2s_block; + g2s_a_tile = g2s_block * NumTilePerBlock; + g2s_a_smem = (g2s_a_smem + 1) % Stage; + } + cp_async_fence(); + } + + + // epilogue + + Tensor sD = make_tensor( + make_smem_ptr(shared_storage.storage_D.smem_D.data()), + SmemLayoutD{} + ); + + R2STiledCopyD r2s_tiled_copy_d; + auto r2s_thr_copy_d = r2s_tiled_copy_d.get_slice(t_idx); + auto tDrD_r2s = r2s_thr_copy_d.retile_S(tDrD); + auto tDsD_r2s = r2s_thr_copy_d.partition_D(sD); + + S2GCopyD s2g_tiled_copy_d; + auto s2g_thr_copy_d = s2g_tiled_copy_d.get_slice(t_idx); + auto tDsD_s2g = s2g_thr_copy_d.partition_S(sD); + auto tDgD_s2g = s2g_thr_copy_d.partition_D(gD); + Tensor cD = make_identity_tensor(make_shape(Int{}, Int{})); + auto tDcD_s2g = s2g_thr_copy_d.partition_D(cD); + + auto tDgD_s2gx = group_modes<1, 3>(tDgD_s2g); // (CPY_, CPY_MN) + auto tDrD_r2sx = group_modes<1, 3>(tDrD_r2s); // (CPY_, CPY_MN) + auto tDcD_s2gx = group_modes<1, 3>(tDcD_s2g); + + int32_t step = size<3>(tDsD_r2s); // pipe + CUTLASS_PRAGMA_UNROLL + for (int32_t i = 
0; i < size<1>(tDrD_r2sx); i += step) { + CUTLASS_PRAGMA_UNROLL + for (int32_t j = 0; j < step; ++j) { + if constexpr (std::is_same::value) { + cute::copy(r2s_tiled_copy_d, tDrD_r2sx(_, i + j), tDsD_r2s(_, 0, 0, j)); + } else { + auto t = make_tensor_like(tDrD_r2sx(_, i + j)); + cute::copy(tDrD_r2sx(_, i + j), t); + cute::copy(r2s_tiled_copy_d, t, tDsD_r2s(_, 0, 0, j)); + } + } + + __syncthreads(); + + // shm -> global + if constexpr (IsEvenM && IsEvenN) { + CUTLASS_PRAGMA_UNROLL + for (int32_t j = 0; j < step; ++j) + cute::copy(s2g_tiled_copy_d, tDsD_s2g(_, 0, 0, j), tDgD_s2gx(_, i + j)); + } else if constexpr (IsEvenN) { + CUTLASS_PRAGMA_UNROLL + for (int32_t j = 0; j < step; ++j) { + if (get<0>(tDcD_s2gx(0, i + j)) < m_limit) + cute::copy(s2g_tiled_copy_d, tDsD_s2g(_, 0, 0, j), tDgD_s2gx(_, i + j)); + } + } else if constexpr (IsEvenM) { + CUTLASS_PRAGMA_UNROLL + for (int32_t j = 0; j < step; ++j) + if (get<1>(tDcD_s2gx(size<0>(tDsD_s2g) - 1, i + j)) < n_limit) { + cute::copy(s2g_tiled_copy_d, tDsD_s2g(_, 0, 0, j), tDgD_s2gx(_, i + j)); + } else { + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<0>(tDsD_s2g); ++k) + if (get<1>(tDcD_s2gx(k, i + j)) < n_limit) + tDgD_s2gx(k, i + j) = tDsD_s2g(k, 0, 0, j); + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int32_t j = 0; j < step; ++j) + if (get<0>(tDcD_s2gx(0, i + j)) < m_limit) { + if (get<1>(tDcD_s2gx(size<0>(tDsD_s2g) - 1, i + j)) < n_limit) { + cute::copy(s2g_tiled_copy_d, tDsD_s2g(_, 0, 0, j), tDgD_s2gx(_, i + j)); + } else { + for (int32_t k = 0; k < size<0>(tDsD_s2g); ++k) + if (get<1>(tDcD_s2gx(k, i + j)) < n_limit) + tDgD_s2gx(k, i + j) = tDsD_s2g(k, 0, 0, j); + } + } + } + __syncthreads(); + } + } + +}; + \ No newline at end of file diff --git a/turbodiffusion/ops/gemm/kernel_rocwmma.hpp b/turbodiffusion/ops/gemm/kernel_rocwmma.hpp new file mode 100644 index 0000000..ee06c67 --- /dev/null +++ b/turbodiffusion/ops/gemm/kernel_rocwmma.hpp @@ -0,0 +1,332 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Citation (please cite if you use this code): + * + * @article{zhang2025turbodiffusion, + * title={TurboDiffusion: Accelerating Video Diffusion Models by 100-200 Times}, + * author={Zhang, Jintao and Zheng, Kaiwen and Jiang, Kai and Wang, Haoxu and Stoica, Ion and Gonzalez, Joseph E and Chen, Jianfei and Zhu, Jun}, + * journal={arXiv preprint arXiv:2512.16093}, + * year={2025} + * } + * + * rocWMMA-based GEMM kernel for AMD RDNA3 GPUs. + * This kernel performs int8 GEMM with per-block quantization scaling. 
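+ * Tiling sketch derived from the constants below (not normative): each thread
+ * block runs 4 wave32 waves, each wave owns a 2x2 grid of 16x16x16 WMMA tiles,
+ * so one block produces a 128x32 output macro-tile. K is walked in 128-wide
+ * quantization blocks: int8 products accumulate into int32 fragments for one
+ * block, then are folded into float accumulators scaled by the per-block
+ * factors from AS and BS.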
+ * + * Based on rocWMMA from https://github.com/ROCm/rocm-libraries/tree/develop/projects/rocwmma + */ + +#pragma once + +#include + +// Undefine the no-half-conversion macros that PyTorch sets +// rocWMMA needs these conversions to work properly +#ifdef __HIP_NO_HALF_OPERATORS__ +#undef __HIP_NO_HALF_OPERATORS__ +#endif +#ifdef __HIP_NO_HALF_CONVERSIONS__ +#undef __HIP_NO_HALF_CONVERSIONS__ +#endif + +#include +#include +#include + +#include "common/platform.hpp" + +using namespace rocwmma; + +// Helper for float to output type conversion +template +TURBO_DEVICE TURBO_INLINE T float_to_output(float val); + +template <> +TURBO_DEVICE TURBO_INLINE __half float_to_output<__half>(float val) { + return __float2half(val); +} + +template <> +TURBO_DEVICE TURBO_INLINE hip_bfloat16 float_to_output(float val) { + return hip_bfloat16(val); +} + +template <> +TURBO_DEVICE TURBO_INLINE float float_to_output(float val) { + return val; +} + +template <> +TURBO_DEVICE TURBO_INLINE int32_t float_to_output(float val) { + return static_cast(val); +} + +// RDNA3 (gfx11) specific parameters +// Wave size: 32, Block sizes: 16x16x16 +namespace rdna3 { + constexpr uint32_t ROCWMMA_M = 16u; + constexpr uint32_t ROCWMMA_N = 16u; + constexpr uint32_t ROCWMMA_K = 16u; + constexpr uint32_t WAVE_SIZE = 32u; + constexpr uint32_t QUANT_BLOCK = 128u; // Quantization block size +} + +template < + class OutputDtype_, + bool IsEvenM, + bool IsEvenN +> +struct GemmKernelRocWMMA { + using ElementA = int8_t; + using ElementB = int8_t; + using OutputDtype = OutputDtype_; + using AccumulatorDtype = int32_t; + using ComputeDtype = int32_t; // MMA accumulator type + + // Tile sizes + static constexpr int TileM = rdna3::ROCWMMA_M; + static constexpr int TileN = rdna3::ROCWMMA_N; + static constexpr int TileK = rdna3::ROCWMMA_K; + static constexpr int BlockSize = rdna3::QUANT_BLOCK; + static constexpr int WaveSize = rdna3::WAVE_SIZE; + + // Warp tile: how many MMA tiles each wave computes + static constexpr int WarpTileM = 2; // 2 tiles in M direction = 32 + static constexpr int WarpTileN = 2; // 2 tiles in N direction = 32 + + // Thread block configuration + static constexpr int TBlockX = 128; // 4 waves + static constexpr int TBlockY = 1; + static constexpr int NumWarps = TBlockX / WaveSize; // 4 waves + + // Macro tile computed by entire thread block + static constexpr int MacroTileM = NumWarps * WarpTileM * TileM; // 4 * 2 * 16 = 128 + static constexpr int MacroTileN = TBlockY * WarpTileN * TileN; // 1 * 2 * 16 = 32 + + // Fragment types - using row_major for A and col_major for B (NT layout) + using FragA = fragment; + using FragB = fragment; + using FragAcc = fragment; + + struct Params { + void const* Aptr; + void const* ASptr; + void const* Bptr; + void const* BSptr; + void* Dptr; + int64_t const m; + int64_t const n; + int64_t const k; + int const swizzle_dir; + int const swizzle_size; + }; + + using Arguments = Params; + + static constexpr int ThreadNum = TBlockX * TBlockY; + // Shared memory for storing fragments: each wave needs TileM*TileN*sizeof(float) per tile + // With WarpTileM=2, WarpTileN=2, and NumWarps=4 waves: + // Size = NumWarps * WarpTileM * WarpTileN * TileM * TileN * sizeof(float) + // = 4 * 2 * 2 * 16 * 16 * 4 = 16384 bytes + static constexpr int ShmSize = NumWarps * WarpTileM * WarpTileN * TileM * TileN * sizeof(float); + static constexpr int MaxThreadsPerBlock = ThreadNum; + static constexpr int MinBlocksPerMultiprocessor = 1; + + static bool can_implement(int64_t m, int64_t n, int64_t k) { + if (k % 
BlockSize != 0) return false; + if ((n * sizeof(OutputDtype)) % 16 != 0) return false; + return true; + } + + static Params to_underlying_arguments(Arguments const& args) { + return args; + } + + TURBO_HOST_DEVICE + static int64_t cdiv(int64_t a, int64_t b) { + return (a + b - 1) / b; + } + + static dim3 get_grid_size(int64_t m, int64_t n) { + int64_t grid_m = cdiv(m, MacroTileM); + int64_t grid_n = cdiv(n, MacroTileN); + return dim3(grid_m * grid_n); + } + + TURBO_DEVICE + void operator()(Params const& params, char* smem_data) { + int64_t const m = params.m; + int64_t const n = params.n; + int64_t const k = params.k; + + // Wave and lane indices + int waveId = threadIdx.x / WaveSize; + int laneId = threadIdx.x % WaveSize; + + // Grid dimensions + int64_t grid_m = cdiv(m, MacroTileM); + int64_t grid_n = cdiv(n, MacroTileN); + + // Block coordinates (linear to 2D) + int64_t block_m = blockIdx.x % grid_m; + int64_t block_n = blockIdx.x / grid_m; + + // Base coordinates for this wave's output tiles + int64_t wave_m_base = block_m * MacroTileM + waveId * WarpTileM * TileM; + int64_t wave_n_base = block_n * MacroTileN; + + // Pointers + ElementA const* A = reinterpret_cast(params.Aptr); + ElementB const* B = reinterpret_cast(params.Bptr); + float const* AS = reinterpret_cast(params.ASptr); + float const* BS = reinterpret_cast(params.BSptr); + + // Number of quantization blocks in K dimension + int64_t num_quant_blocks_k = k / BlockSize; + + // Float accumulators for dequantized results + float floatAcc[WarpTileM][WarpTileN][FragAcc::num_elements]; + + // Initialize accumulators + TURBO_PRAGMA_UNROLL + for (int wm = 0; wm < WarpTileM; ++wm) { + TURBO_PRAGMA_UNROLL + for (int wn = 0; wn < WarpTileN; ++wn) { + TURBO_PRAGMA_UNROLL + for (int i = 0; i < FragAcc::num_elements; ++i) { + floatAcc[wm][wn][i] = 0.0f; + } + } + } + + // Process each quantization block + for (int64_t qb = 0; qb < num_quant_blocks_k; ++qb) { + int64_t k_start = qb * BlockSize; + int64_t k_end = k_start + BlockSize; + + // Integer accumulators for this quant block + FragAcc fragAcc[WarpTileM][WarpTileN]; + TURBO_PRAGMA_UNROLL + for (int wm = 0; wm < WarpTileM; ++wm) { + TURBO_PRAGMA_UNROLL + for (int wn = 0; wn < WarpTileN; ++wn) { + fill_fragment(fragAcc[wm][wn], static_cast(0)); + } + } + + // K-loop within quantization block + for (int64_t kk = k_start; kk < k_end; kk += TileK) { + // Load and compute for each tile in warp tile + TURBO_PRAGMA_UNROLL + for (int wm = 0; wm < WarpTileM; ++wm) { + int64_t tile_m = wave_m_base + wm * TileM; + + FragA fragA; + if (tile_m < m) { + // A is row-major: A[m, k] + load_matrix_sync(fragA, A + tile_m * k + kk, k); + } else { + fill_fragment(fragA, static_cast(0)); + } + + TURBO_PRAGMA_UNROLL + for (int wn = 0; wn < WarpTileN; ++wn) { + int64_t tile_n = wave_n_base + wn * TileN; + + FragB fragB; + if (tile_n < n) { + // B is stored as [N, K] in col-major (K changes fastest when reading B[n, :]) + load_matrix_sync(fragB, B + tile_n * k + kk, k); + } else { + fill_fragment(fragB, static_cast(0)); + } + + // Matrix multiply-accumulate + mma_sync(fragAcc[wm][wn], fragA, fragB, fragAcc[wm][wn]); + } + } + } + + // Dequantize this block's contribution + TURBO_PRAGMA_UNROLL + for (int wm = 0; wm < WarpTileM; ++wm) { + int64_t tile_m = wave_m_base + wm * TileM; + int64_t qblock_m = tile_m / BlockSize; + + // Get scale for A + float scale_a = 1.0f; + if (qblock_m < cdiv(m, BlockSize) && qb < num_quant_blocks_k) { + scale_a = AS[qblock_m * num_quant_blocks_k + qb]; + } + + TURBO_PRAGMA_UNROLL + 
for (int wn = 0; wn < WarpTileN; ++wn) { + int64_t tile_n = wave_n_base + wn * TileN; + int64_t qblock_n = tile_n / BlockSize; + + // Get scale for B + float scale_b = 1.0f; + if (qblock_n < cdiv(n, BlockSize) && qb < num_quant_blocks_k) { + scale_b = BS[qblock_n * num_quant_blocks_k + qb]; + } + + float scale = scale_a * scale_b; + + // Accumulate dequantized values + TURBO_PRAGMA_UNROLL + for (int i = 0; i < FragAcc::num_elements; ++i) { + floatAcc[wm][wn][i] += static_cast(fragAcc[wm][wn].x[i]) * scale; + } + } + } + } + + // Store final results using store_matrix_sync to shared memory temp buffer + // This ensures correct fragment layout interpretation + OutputDtype* D = reinterpret_cast(params.Dptr); + + // Each wave gets its own section of shared memory + float* smem_temp = reinterpret_cast(smem_data); + float* wave_smem = smem_temp + waveId * WarpTileM * WarpTileN * TileM * TileN; + + TURBO_PRAGMA_UNROLL + for (int wm = 0; wm < WarpTileM; ++wm) { + int64_t tile_m = wave_m_base + wm * TileM; + + TURBO_PRAGMA_UNROLL + for (int wn = 0; wn < WarpTileN; ++wn) { + int64_t tile_n = wave_n_base + wn * TileN; + + // Create a float fragment from the accumulated values + fragment fragFloat; + TURBO_PRAGMA_UNROLL + for (int i = 0; i < FragAcc::num_elements; ++i) { + fragFloat.x[i] = floatAcc[wm][wn][i]; + } + + // Store to wave's temp buffer using rocWMMA (row-major layout) + float* tile_buf = wave_smem + (wm * WarpTileN + wn) * TileM * TileN; + store_matrix_sync(tile_buf, fragFloat, TileN, mem_row_major); + + __syncthreads(); + + // Now read from tile_buf with linear indexing + for (int e = laneId; e < TileM * TileN; e += WaveSize) { + int local_row = e / TileN; + int local_col = e % TileN; + + int64_t global_row = tile_m + local_row; + int64_t global_col = tile_n + local_col; + + if (global_row < m && global_col < n) { + D[global_row * n + global_col] = float_to_output(tile_buf[e]); + } + } + + __syncthreads(); + } + } + } +}; diff --git a/turbodiffusion/ops/gemm/launch_hip.hpp b/turbodiffusion/ops/gemm/launch_hip.hpp new file mode 100644 index 0000000..ab4f332 --- /dev/null +++ b/turbodiffusion/ops/gemm/launch_hip.hpp @@ -0,0 +1,66 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +/* + * Copyright (c) 2025 by TurboDiffusion team. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Citation (please cite if you use this code): + * + * @article{zhang2025turbodiffusion, + * title={TurboDiffusion: Accelerating Video Diffusion Models by 100-200 Times}, + * author={Zhang, Jintao and Zheng, Kaiwen and Jiang, Kai and Wang, Haoxu and Stoica, Ion and Gonzalez, Joseph E and Chen, Jianfei and Zhu, Jun}, + * journal={arXiv preprint arXiv:2512.16093}, + * year={2025} + * } + */ + +#pragma once + +#include "common/common_hip.hpp" +#include "common/launch_hip.hpp" +#include "gemm/kernel_hip.hpp" + + +template +bool int8_gemm_( + int8_t const *Aptr, float const *ASptr, + int8_t const *Bptr, float const *BSptr, + OutputDtype* Dptr, int64_t m, int64_t n, int64_t k, + int swizzle_dir = 1, int swizzle_size_log = 0, + hipStream_t stream = nullptr +) { + BOOL_SWITCH(m % 128 == 0, IsEvenM, [&] { + BOOL_SWITCH(n % 128 == 0, IsEvenN, [&] { + using Kernel = GemmKernel; + if (!Kernel::can_implement(m, n, k)) + return false; + using Args = typename Kernel::Arguments; + Args args { + (void*)Aptr, (void*)ASptr, + (void*)Bptr, (void*)BSptr, (void*)Dptr, + m, n, k, swizzle_dir, + swizzle_size_log + }; + + auto params = Kernel::to_underlying_arguments(args); + + static constexpr size_t ShmSize = Kernel::ShmSize; + dim3 grid_shape = Kernel::get_grid_size(m, n); + dim3 block_shape = dim3(Kernel::ThreadNum); + auto func = device_kernel; + if (ShmSize >= 48 * 1024) { + hipFuncSetAttribute( + func, + hipFuncAttributeMaxDynamicSharedMemorySize, + ShmSize + ); + } + hipLaunchKernelGGL(( func), dim3(grid_shape), dim3(block_shape), ShmSize, stream, + params + ); + return true; + }); + }); + return true; +} diff --git a/turbodiffusion/ops/gemm/launch_rocwmma.hpp b/turbodiffusion/ops/gemm/launch_rocwmma.hpp new file mode 100644 index 0000000..9b975a7 --- /dev/null +++ b/turbodiffusion/ops/gemm/launch_rocwmma.hpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * rocWMMA GEMM kernel launch wrapper for AMD GPUs. + */ + +#pragma once + +#include +#include "common/platform.hpp" +#include "gemm/kernel_rocwmma.hpp" + +#define BOOL_SWITCH(COND, CONST_NAME, ...) 
\ +[&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return (__VA_ARGS__)(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return (__VA_ARGS__)(); \ + } \ +}() + +template +__global__ void rocwmma_gemm_kernel( + typename Kernel::Params const params +) { + extern __shared__ char smem[]; + Kernel op; + op(params, smem); +} + +template +bool int8_gemm_rocwmma( + int8_t const* Aptr, float const* ASptr, + int8_t const* Bptr, float const* BSptr, + OutputDtype* Dptr, int64_t m, int64_t n, int64_t k, + int swizzle_dir = 1, int swizzle_size_log = 0, + hipStream_t stream = nullptr +) { + BOOL_SWITCH(m % 128 == 0, IsEvenM, [&] { + BOOL_SWITCH(n % 128 == 0, IsEvenN, [&] { + using Kernel = GemmKernelRocWMMA; + + if (!Kernel::can_implement(m, n, k)) { + return false; + } + + using Args = typename Kernel::Arguments; + Args args{ + (void*)Aptr, (void*)ASptr, + (void*)Bptr, (void*)BSptr, (void*)Dptr, + m, n, k, swizzle_dir, + swizzle_size_log + }; + + auto params = Kernel::to_underlying_arguments(args); + + static constexpr size_t ShmSize = Kernel::ShmSize; + dim3 grid_shape = Kernel::get_grid_size(m, n); + dim3 block_shape = dim3(Kernel::ThreadNum); + + auto func = rocwmma_gemm_kernel; + if (ShmSize >= 48 * 1024) { + hipFuncSetAttribute( + func, + hipFuncAttributeMaxDynamicSharedMemorySize, + ShmSize + ); + } + + hipLaunchKernelGGL(func, grid_shape, block_shape, ShmSize, stream, params); + return true; + }); + }); + return true; +} + diff --git a/turbodiffusion/ops/gemm/utils_hip.hpp b/turbodiffusion/ops/gemm/utils_hip.hpp new file mode 100644 index 0000000..76356b4 --- /dev/null +++ b/turbodiffusion/ops/gemm/utils_hip.hpp @@ -0,0 +1,130 @@ +// !!! This is a file automatically generated by hipify!!! +/* + * Copyright (c) 2025 by TurboDiffusion team. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Citation (please cite if you use this code): + * + * @article{zhang2025turbodiffusion, + * title={TurboDiffusion: Accelerating Video Diffusion Models by 100-200 Times}, + * author={Zhang, Jintao and Zheng, Kaiwen and Jiang, Kai and Wang, Haoxu and Stoica, Ion and Gonzalez, Joseph E and Chen, Jianfei and Zhu, Jun}, + * journal={arXiv preprint arXiv:2512.16093}, + * year={2025} + * } + */ + +#pragma once + +#include +#include "cute/tensor_hip.hpp" + +template < + bool IsEven, + class TiledCopy, + class SrcTensor, + class DstTensor, + class PrdTensor +> +CUTLASS_DEVICE void +copy_AB( + TiledCopy const& _copy, + SrcTensor const &S, + DstTensor &D, + PrdTensor const &ID, + const int64_t &i_read, + const int64_t &i_write, + const int64_t &limit +) { + using namespace cute; + if constexpr (IsEven) + cute::copy(_copy, S(_, _, _, i_read), D(_, _, _, i_write)); + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(ID); ++i) + if (get<0>(ID(0, i, 0)) < limit) + cute::copy(_copy, S(_, i, _, i_read), D(_, i, _, i_write)); + } +} + +template +CUTLASS_DEVICE void copy_async( + void const* gmem_src, + void* smem_dst +) { + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_dst));; + asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + :: "r"(smem_int_ptr), + "l"(gmem_src), + "n"(N)); +} + +template +CUTLASS_DEVICE void copy_aligned(const void* src, void* dst, size_t N, int64_t thread_idx) { + static constexpr int NumElementPerLoad = sizeof(LoadType) / sizeof(T); + for (int64_t i = thread_idx * NumElementPerLoad; i < N; i += NumElementPerLoad * NumThreads) { + if (i + NumElementPerLoad <= N) { + copy_async( + (void*)((T*)src + i), + (void*)((T*)dst + i) + ); + } else { + for (int64_t j = 0; j < N - i; ++j) + copy_async( + (void*)((T*)src + i + j), + (void*)((T*)dst + i + j) + ); + } + } +} + +template +CUTLASS_DEVICE void g2s_vector_copy(const void* src, void* dst, size_t N, int64_t thread_idx) { + + uintptr_t src_addr = reinterpret_cast(src); + + if (src_addr % 16 == 0) { + copy_aligned(src, dst, N, thread_idx); + } else if (src_addr % 8 == 0) { + copy_aligned(src, dst, N, thread_idx); + } else if (src_addr % 4 == 0) { + copy_aligned(src, dst, N, thread_idx); + } else { + assert(0); + } + if constexpr (Commit) { + asm volatile("cp.async.commit_group;\n" ::); + } + if constexpr (Wait) { + asm volatile("cp.async.wait_all;\n" ::); + } +} + +template +CUTLASS_DEVICE +static void dequant( + T* mma_accum_ptr, + float* float_accum_ptr, + float scale +) { + static int const ic = 0x4B400000; + if constexpr (FastInt2Float && std::is_same_v) { + CUTLASS_PRAGMA_UNROLL + for (size_t i = 0; i < N; ++i) { + *(float_accum_ptr + i) += (__int_as_float(*(mma_accum_ptr + i)) - __int_as_float(ic)) * scale; + *(mma_accum_ptr + i) = ic; + } + } else if constexpr (std::is_same_v) { + CUTLASS_PRAGMA_UNROLL + for (size_t i = 0; i < N; ++i) { + *(float_accum_ptr + i) += __int2float_rn(*(mma_accum_ptr + i)) * scale; + *(mma_accum_ptr + i) = 0; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (size_t i = 0; i < N; ++i) { + *(float_accum_ptr + i) += (*(mma_accum_ptr + i)) * scale; + *(mma_accum_ptr + i) = 0; + } + } +} \ No newline at end of file diff --git a/turbodiffusion/ops/norm/layernorm.hip b/turbodiffusion/ops/norm/layernorm.hip new file mode 100644 index 0000000..83c78d9 --- /dev/null +++ b/turbodiffusion/ops/norm/layernorm.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * HIP/ROCm LayerNorm kernel. + */ + +#include +#include +#include +#include +#include +#include + +#include "common/common_hip.hpp" +#include "norm/layernorm_hip.hpp" + +auto layer_norm( + at::Tensor const Input, + float eps, + std::optional W, + std::optional const B, + std::optional Output +) { + using ElementIn = float; + using ElementOut = float; + using ElementWeight = float; + + int64_t const m = Input.size(0); + int64_t const n = Input.size(1); + torch::Device const input_device = Input.device(); + + if (!Output.has_value()) { + Output.emplace( + torch::empty( + {m, n}, + torch::TensorOptions().device(input_device).dtype(torch::kFloat32) + ) + ); + } + + void *Iptr = Input.data_ptr(); + void *Wptr = W.has_value() ? W.value().data_ptr() : nullptr; + void *Bptr = B.has_value() ? B.value().data_ptr() : nullptr; + void *Optr = Output.value().data_ptr(); + + BOOL_SWITCH(B.has_value(), BIAS, [&]{ + BOOL_SWITCH(W.has_value(), AFFINE, [&]{ + CONFIG_SWITCH(n, [&]{ + layernorm< + ElementIn, ElementOut, ElementWeight, + AFFINE, BIAS, + MAX_HIDDEN_SIZE, NUM_THR_PER_CTA + >( + Iptr, Wptr, Bptr, + Optr, eps, m, n, + at::hip::getCurrentHIPStream().stream() + ); + }); + }); + }); + + return Output; +} + +void register_layer_norm(pybind11::module_ &m) { + m.def("layer_norm_cuda", &layer_norm); +} + diff --git a/turbodiffusion/ops/norm/layernorm_hip.hpp b/turbodiffusion/ops/norm/layernorm_hip.hpp new file mode 100644 index 0000000..5c5e0f5 --- /dev/null +++ b/turbodiffusion/ops/norm/layernorm_hip.hpp @@ -0,0 +1,221 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +#pragma once + +#include "common/common_hip.hpp" +#include "common/load.hpp" +#include "common/store_hip.hpp" +#include "common/launch_hip.hpp" + +// Helper for output type conversion +namespace turbo_layernorm { +template +__device__ __forceinline__ T from_float(float val) { + return static_cast(val); +} +template <> +__device__ __forceinline__ __half from_float<__half>(float val) { + return __float2half(val); +} +template <> +__device__ __forceinline__ hip_bfloat16 from_float(float val) { + return hip_bfloat16(val); +} +} // namespace turbo_layernorm + + +template < + class InputDtype_, + class OutputDtype_, + class WeightDtype_, + bool Affine_, + bool Bias_, + int MaxHiddenSize_, + int NumThrPerCta_, + bool IsEven +> +class LayerNorm { +public: + using InputDtype = InputDtype_; + using OutputDtype = OutputDtype_; + using WeightDtype = WeightDtype_; + static constexpr int NumThrPerCta = NumThrPerCta_; + static constexpr int MaxHiddenSize = MaxHiddenSize_; + static constexpr bool Affine = Affine_; + static constexpr bool Bias = Bias_; + + static constexpr size_t ShmSize = 32; + static constexpr int NumElementPerThread = MaxHiddenSize / NumThrPerCta; + + static_assert(MaxHiddenSize % NumThrPerCta == 0); + + struct Params { + void const *Iptr; + void const *Wptr; + void const *Bptr; + void *Optr; + float eps; + int64_t m; + int64_t n; + }; + + using Arguments = Params; + + static Params to_underlying_arguments(Arguments const& args) { + return args; + } + + static dim3 get_grid_size(int64_t m, int64_t n) { + return dim3(m); + } + + static dim3 get_cta_size(int64_t m, int64_t n) { + return dim3(NumThrPerCta, 1, 1); + } + + CUTLASS_DEVICE + void operator()(Params const& params, char *shared_data) { + int const blk_m = blockIdx.x; + int const blk_n = 1; + int tidx = threadIdx.x; + float x[NumElementPerThread]; + + // load + 
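+        // load: each thread pulls NumElementPerThread values of row blk_m into
+        // registers x[] (exact element-to-thread mapping is determined by Loader)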
Loader loader; + loader.load(params.Iptr, x, params.m, params.n, blk_m, 0, tidx); + + // mean reduction + float u = _reduce_sum(x, shared_data) / params.n; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + x[i] -= u; + + __syncthreads(); + // var reduction + float v = sqrtf(_reduce_square(x, shared_data) / params.n + params.eps); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + x[i] /= v; + + if constexpr (Affine) { + // load weight + Loader weight_loader; + float w[NumElementPerThread]; + weight_loader.load(params.Wptr, w, 1, params.n, 0, 0, tidx); + if constexpr (Bias) { + float b[NumElementPerThread]; + weight_loader.load(params.Bptr, b, 1, params.n, 0, 0, tidx); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + x[i] = x[i] * w[i] + b[i]; + } else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + x[i] = x[i] * w[i]; + } + } + + // save y + { + Saver saver; + if constexpr (std::is_same_v) { + saver.store(params.Optr, nullptr, x, 0, params.m, params.n, blk_m, 0, tidx); + } else { + OutputDtype tmp[NumElementPerThread]; + for (int i = 0; i < NumElementPerThread; ++i) + tmp[i] = turbo_layernorm::from_float(x[i]); + saver.store(params.Optr, nullptr, tmp, 0, params.m, params.n, blk_m, 0, tidx); + } + } + + } + +private: + CUTLASS_DEVICE + float _reduce_square(float *reg, char *shared_data) { + // thread + float sum_square = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + sum_square += reg[i] * reg[i]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 16; i >= 1; i >>= 1) { + sum_square += __shfl_down(sum_square, i, 32); + } + if (threadIdx.x == 0) { + *(float*)shared_data = 0; + } + __syncthreads(); + + if (threadIdx.x % 32 == 0) { + atomicAdd((float*)shared_data, sum_square); + } + + __syncthreads(); + sum_square = *(float*)shared_data; + return sum_square; + } + + CUTLASS_DEVICE + float _reduce_sum(float *reg, char *shared_data) { + // thread + float sum = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + sum += reg[i]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 16; i >= 1; i >>= 1) { + sum += __shfl_down(sum, i, 32); + } + if (threadIdx.x == 0) { + *(float*)shared_data = 0; + } + __syncthreads(); + + if (threadIdx.x % 32 == 0) { + atomicAdd((float*)shared_data, sum); + } + + __syncthreads(); + sum = *(float*)shared_data; + return sum; + } +}; + +template < + class InputDtype, + class OutputDtype, + class WeightDtype, + bool Affine, + bool Bias, + int MaxHiddenSize, + int NumThrPerCta +> +bool layernorm( + void const *Iptr, void const *Wptr, void const *Bptr, + void *Optr, float eps, int64_t m, int64_t n, + hipStream_t stream = nullptr +) { + BOOL_SWITCH(n % MaxHiddenSize == 0, IsEven, [&] { + using Kernel = LayerNorm< + InputDtype, OutputDtype, WeightDtype, + Affine, Bias, + MaxHiddenSize, NumThrPerCta, + IsEven>; + using Arguments = typename Kernel::Arguments; + Arguments args = { + Iptr, Wptr, Bptr, Optr, + eps, m, n + }; + auto params = Kernel::to_underlying_arguments(args); + auto grid_shape = Kernel::get_grid_size(m, n); + auto cta_shape = Kernel::get_cta_size(m, n); + static constexpr size_t ShmSize = Kernel::ShmSize; + launch_kernel(params, grid_shape, cta_shape, ShmSize, stream); + }); + return true; +} \ No newline at end of file diff --git a/turbodiffusion/ops/norm/norm_rocm.hpp b/turbodiffusion/ops/norm/norm_rocm.hpp new file mode 100644 index 0000000..1dd636c --- /dev/null +++ b/turbodiffusion/ops/norm/norm_rocm.hpp @@ -0,0 +1,428 
@@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Normalization kernels for AMD GPUs using HIP. + */ + +#pragma once + +#include +#include +#include "common/platform.hpp" + +TURBO_HOST_DEVICE inline int64_t cdiv_norm(int64_t a, int64_t b) { + return (a + b - 1) / b; +} + +#define MIN_NORM(a, b) ((a) > (b) ? (b) : (a)) +#define MAX_NORM(a, b) ((a) > (b) ? (a) : (b)) + +#define BOOL_SWITCH_NORM(COND, CONST_NAME, ...) \ +[&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return (__VA_ARGS__)(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return (__VA_ARGS__)(); \ + } \ +}() + +// RMSNorm Kernel +template < + class InputDtype_, + class OutputDtype_, + class WeightDtype_, + int MaxHiddenSize_, + int NumThrPerCta_, + bool IsEven +> +class RMSNormHIP { +public: + using InputDtype = InputDtype_; + using OutputDtype = OutputDtype_; + using WeightDtype = WeightDtype_; + static constexpr int NumThrPerCta = NumThrPerCta_; + static constexpr int MaxHiddenSize = MaxHiddenSize_; + static constexpr size_t ShmSize = 32; + static constexpr int NumElementPerThread = MaxHiddenSize / NumThrPerCta; + + static_assert(MaxHiddenSize % NumThrPerCta == 0); + + struct Params { + void const* Iptr; + void const* Wptr; + void* Optr; + float eps; + int64_t m; + int64_t n; + }; + + using Arguments = Params; + + static Params to_underlying_arguments(Arguments const& args) { + return args; + } + + static dim3 get_grid_size(int64_t m, int64_t n) { + return dim3(m); + } + + static dim3 get_cta_size(int64_t m, int64_t n) { + return dim3(NumThrPerCta, 1, 1); + } + + TURBO_DEVICE + void operator()(Params const& params, char* shared_data) { + int blk_m = blockIdx.x; + int tidx = threadIdx.x; + float x[NumElementPerThread]; + float w[NumElementPerThread]; + + // Load input + load_input(params.Iptr, x, params.m, params.n, blk_m, tidx); + + // RMS reduction + float rms = sqrtf(reduce_square(x, shared_data) / params.n + params.eps); + + // Load weight + load_weight(params.Wptr, w, params.n, tidx); + + // Normalize + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + x[i] = w[i] * x[i] / rms; + } + + // Store output + store_output(params.Optr, x, params.m, params.n, blk_m, tidx); + } + +private: + TURBO_DEVICE + void load_input(void const* input_ptr, float* reg, int64_t m, int64_t n, int blk_m, int tidx) { + InputDtype const* input = reinterpret_cast(input_ptr); + int64_t offset = blk_m * n + tidx * NumElementPerThread; + + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + if (IsEven || (tidx * NumElementPerThread + i) < n) { + reg[i] = static_cast(input[offset + i]); + } else { + reg[i] = 0.0f; + } + } + } + + TURBO_DEVICE + void load_weight(void const* weight_ptr, float* reg, int64_t n, int tidx) { + if (weight_ptr == nullptr) { + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + reg[i] = 1.0f; + } + return; + } + + WeightDtype const* weight = reinterpret_cast(weight_ptr); + int64_t offset = tidx * NumElementPerThread; + + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + if (IsEven || (tidx * NumElementPerThread + i) < n) { + reg[i] = static_cast(weight[offset + i]); + } else { + reg[i] = 1.0f; + } + } + } + + TURBO_DEVICE + void store_output(void* output_ptr, float* reg, int64_t m, int64_t n, int blk_m, int tidx) { + OutputDtype* output = reinterpret_cast(output_ptr); + int64_t offset = blk_m * n + tidx * 
NumElementPerThread; + + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + if (IsEven || (tidx * NumElementPerThread + i) < n) { + output[offset + i] = static_cast(reg[i]); + } + } + } + + TURBO_DEVICE + float reduce_square(float* reg, char* shared_data) { + float sum_square = 0; + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + sum_square += reg[i] * reg[i]; + } + + TURBO_PRAGMA_UNROLL + for (int i = 16; i >= 1; i >>= 1) { + sum_square += __shfl_down(sum_square, i, 32); + } + + if (threadIdx.x == 0) { + *(float*)shared_data = 0; + } + __syncthreads(); + + if (threadIdx.x % 32 == 0) { + atomicAdd((float*)shared_data, sum_square); + } + + __syncthreads(); + sum_square = *(float*)shared_data; + return sum_square; + } +}; + +// LayerNorm Kernel +template < + class InputDtype_, + class OutputDtype_, + class WeightDtype_, + bool Affine, + bool Bias, + int MaxHiddenSize_, + int NumThrPerCta_, + bool IsEven +> +class LayerNormHIP { +public: + using InputDtype = InputDtype_; + using OutputDtype = OutputDtype_; + using WeightDtype = WeightDtype_; + static constexpr int NumThrPerCta = NumThrPerCta_; + static constexpr int MaxHiddenSize = MaxHiddenSize_; + static constexpr size_t ShmSize = 64; + static constexpr int NumElementPerThread = MaxHiddenSize / NumThrPerCta; + + static_assert(MaxHiddenSize % NumThrPerCta == 0); + + struct Params { + void const* Iptr; + void const* Wptr; + void const* Bptr; + void* Optr; + float eps; + int64_t m; + int64_t n; + }; + + using Arguments = Params; + + static Params to_underlying_arguments(Arguments const& args) { + return args; + } + + static dim3 get_grid_size(int64_t m, int64_t n) { + return dim3(m); + } + + static dim3 get_cta_size(int64_t m, int64_t n) { + return dim3(NumThrPerCta, 1, 1); + } + + TURBO_DEVICE + void operator()(Params const& params, char* shared_data) { + int blk_m = blockIdx.x; + int tidx = threadIdx.x; + float x[NumElementPerThread]; + float w[NumElementPerThread]; + float b[NumElementPerThread]; + + // Load input + load_input(params.Iptr, x, params.m, params.n, blk_m, tidx); + + // Compute mean and variance + float mean, var; + reduce_mean_var(x, shared_data, params.n, mean, var); + float rstd = rsqrtf(var + params.eps); + + // Load weight and bias + if constexpr (Affine) { + load_weight(params.Wptr, w, params.n, tidx); + } + if constexpr (Bias) { + load_weight(params.Bptr, b, params.n, tidx); + } + + // Normalize + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + float val = (x[i] - mean) * rstd; + if constexpr (Affine) { + val *= w[i]; + } + if constexpr (Bias) { + val += b[i]; + } + x[i] = val; + } + + // Store output + store_output(params.Optr, x, params.m, params.n, blk_m, tidx); + } + +private: + TURBO_DEVICE + void load_input(void const* input_ptr, float* reg, int64_t m, int64_t n, int blk_m, int tidx) { + InputDtype const* input = reinterpret_cast(input_ptr); + int64_t offset = blk_m * n + tidx * NumElementPerThread; + + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + if (IsEven || (tidx * NumElementPerThread + i) < n) { + reg[i] = static_cast(input[offset + i]); + } else { + reg[i] = 0.0f; + } + } + } + + TURBO_DEVICE + void load_weight(void const* weight_ptr, float* reg, int64_t n, int tidx) { + if (weight_ptr == nullptr) { + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + reg[i] = 1.0f; + } + return; + } + + WeightDtype const* weight = reinterpret_cast(weight_ptr); + int64_t offset = tidx * NumElementPerThread; + + 
TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + if (IsEven || (tidx * NumElementPerThread + i) < n) { + reg[i] = static_cast(weight[offset + i]); + } else { + reg[i] = 0.0f; + } + } + } + + TURBO_DEVICE + void store_output(void* output_ptr, float* reg, int64_t m, int64_t n, int blk_m, int tidx) { + OutputDtype* output = reinterpret_cast(output_ptr); + int64_t offset = blk_m * n + tidx * NumElementPerThread; + + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + if (IsEven || (tidx * NumElementPerThread + i) < n) { + output[offset + i] = static_cast(reg[i]); + } + } + } + + TURBO_DEVICE + void reduce_mean_var(float* reg, char* shared_data, int64_t n, float& mean, float& var) { + float sum = 0; + float sum_sq = 0; + + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + sum += reg[i]; + sum_sq += reg[i] * reg[i]; + } + + // Warp reduction + TURBO_PRAGMA_UNROLL + for (int i = 16; i >= 1; i >>= 1) { + sum += __shfl_down(sum, i, 32); + sum_sq += __shfl_down(sum_sq, i, 32); + } + + float* smem = (float*)shared_data; + if (threadIdx.x == 0) { + smem[0] = 0; + smem[1] = 0; + } + __syncthreads(); + + if (threadIdx.x % 32 == 0) { + atomicAdd(&smem[0], sum); + atomicAdd(&smem[1], sum_sq); + } + + __syncthreads(); + + sum = smem[0]; + sum_sq = smem[1]; + + mean = sum / n; + var = sum_sq / n - mean * mean; + } +}; + +// Kernel launchers +template +__global__ void norm_kernel_hip(typename Kernel::Params const params) { + extern __shared__ char smem[]; + Kernel op; + op(params, smem); +} + +template < + class InputDtype, + class OutputDtype, + class WeightDtype, + int MaxHiddenSize, + int NumThrPerCta +> +bool rmsnorm_hip( + void const* Iptr, void const* Wptr, + void* Optr, float eps, + int64_t m, int64_t n, + hipStream_t stream = nullptr +) { + BOOL_SWITCH_NORM(n % MaxHiddenSize == 0, IsEven, [&] { + using Kernel = RMSNormHIP; + using Arguments = typename Kernel::Arguments; + + Arguments args = {Iptr, Wptr, Optr, eps, m, n}; + auto params = Kernel::to_underlying_arguments(args); + auto grid_shape = Kernel::get_grid_size(m, n); + auto cta_shape = Kernel::get_cta_size(m, n); + static constexpr size_t ShmSize = Kernel::ShmSize; + + hipLaunchKernelGGL(norm_kernel_hip, grid_shape, cta_shape, ShmSize, stream, params); + }); + return true; +} + +template < + class InputDtype, + class OutputDtype, + class WeightDtype, + bool Affine, + bool Bias, + int MaxHiddenSize, + int NumThrPerCta +> +bool layernorm_hip( + void const* Iptr, void const* Wptr, void const* Bptr, + void* Optr, float eps, + int64_t m, int64_t n, + hipStream_t stream = nullptr +) { + BOOL_SWITCH_NORM(n % MaxHiddenSize == 0, IsEven, [&] { + using Kernel = LayerNormHIP; + using Arguments = typename Kernel::Arguments; + + Arguments args = {Iptr, Wptr, Bptr, Optr, eps, m, n}; + auto params = Kernel::to_underlying_arguments(args); + auto grid_shape = Kernel::get_grid_size(m, n); + auto cta_shape = Kernel::get_cta_size(m, n); + static constexpr size_t ShmSize = Kernel::ShmSize; + + hipLaunchKernelGGL(norm_kernel_hip, grid_shape, cta_shape, ShmSize, stream, params); + }); + return true; +} + diff --git a/turbodiffusion/ops/norm/rmsnorm.hip b/turbodiffusion/ops/norm/rmsnorm.hip new file mode 100644 index 0000000..fc8af49 --- /dev/null +++ b/turbodiffusion/ops/norm/rmsnorm.hip @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * HIP/ROCm RMSNorm kernel. 
+ */ + +#include +#include +#include +#include +#include +#include + +#include "common/common_hip.hpp" +#include "norm/rmsnorm_hip.hpp" + +auto rms_norm( + at::Tensor const& Input, + float eps, + const std::optional& Weight, + std::optional& Output +) { + using ElementIn = float; + using ElementOut = float; + using ElementWeight = float; + + int64_t const m = Input.size(0); + int64_t const n = Input.size(1); + torch::Device const input_device = Input.device(); + + if (!Output.has_value()) { + Output.emplace( + torch::empty( + {m, n}, + torch::TensorOptions().device(input_device).dtype(torch::kFloat32) + ) + ); + } + + void *Iptr = Input.data_ptr(); + void *Wptr = Weight.has_value() ? Weight.value().data_ptr() : nullptr; + void *Optr = Output.value().data_ptr(); + + CONFIG_SWITCH(n, [&]{ + rmsnorm< + ElementIn, ElementOut, ElementWeight, + MAX_HIDDEN_SIZE, NUM_THR_PER_CTA + >( + Iptr, Wptr, + Optr, + eps, m, n, + at::hip::getCurrentHIPStream().stream() + ); + }); + + return Output; +} + +void register_rms_norm(pybind11::module_ &m) { + m.def("rms_norm_cuda", &rms_norm); +} + diff --git a/turbodiffusion/ops/norm/rmsnorm_hip.hpp b/turbodiffusion/ops/norm/rmsnorm_hip.hpp new file mode 100644 index 0000000..b4dbda3 --- /dev/null +++ b/turbodiffusion/ops/norm/rmsnorm_hip.hpp @@ -0,0 +1,166 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +#pragma once + +#include "common/common_hip.hpp" +#include "common/load.hpp" +#include "common/store_hip.hpp" +#include "common/launch_hip.hpp" + +// Helper for output type conversion +namespace turbo_norm { +template +__device__ __forceinline__ T from_float(float val) { + return static_cast(val); +} +template <> +__device__ __forceinline__ __half from_float<__half>(float val) { + return __float2half(val); +} +template <> +__device__ __forceinline__ hip_bfloat16 from_float(float val) { + return hip_bfloat16(val); +} +} // namespace turbo_norm + + +template < + class InputDtype_, + class OutputDtype_, + class WeightDtype_, + int MaxHiddenSize_, + int NumThrPerCta_, + bool IsEven +> +class RMSNorm { +public: + using InputDtype = InputDtype_; + using OutputDtype = OutputDtype_; + using WeightDtype = WeightDtype_; + static constexpr int NumThrPerCta = NumThrPerCta_; + static constexpr int MaxHiddenSize = MaxHiddenSize_; + + static constexpr size_t ShmSize = 32; + static constexpr int NumElementPerThread = MaxHiddenSize / NumThrPerCta; + + static_assert(MaxHiddenSize % NumThrPerCta == 0); + + struct Params { + void const *Iptr; + void const *Wptr; + void *Optr; + float eps; + int64_t m; + int64_t n; + }; + + using Arguments = Params; + + static Params to_underlying_arguments(Arguments const& args) { + return args; + } + + static dim3 get_grid_size(int64_t m, int64_t n) { + return dim3(m); + } + + static dim3 get_cta_size(int64_t m, int64_t n) { + return dim3(NumThrPerCta, 1, 1); + } + + CUTLASS_DEVICE + void operator()(Params const& params, char *shared_data) { + int const blk_m = blockIdx.x; + int const blk_n = 1; + int tidx = threadIdx.x; + float x[NumElementPerThread]; + + // load + Loader loader; + loader.load(params.Iptr, x, params.m, params.n, blk_m, 0, tidx); + + // rms reduction + float rms = sqrtf(_reduce_square(x, shared_data) / params.n + params.eps); + + // load weight + Loader weight_loader; + float w[NumElementPerThread]; + loader.load(params.Wptr, w, 1, params.n, 0, 0, tidx); + + // norm + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + x[i] = w[i] * x[i] / rms ; + + // save y + 
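+        // reuse the x registers directly when the output dtype is float; otherwise
+        // the w registers serve as scratch for the converted values before the store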
OutputDtype *output_reg = (OutputDtype*)x; + if constexpr (!std::is_same_v) { + output_reg = (OutputDtype*)w; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + output_reg[i] = turbo_norm::from_float(x[i]); + } + Saver saver; + saver.store(params.Optr, nullptr, output_reg, 0, params.m, params.n, blk_m, 0, tidx); + + } + +private: + CUTLASS_DEVICE + float _reduce_square(float *reg, char *shared_data) { + // thread + float sum_square = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + sum_square += reg[i] * reg[i]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 16; i >= 1; i >>= 1) { + sum_square += __shfl_down(sum_square, i, 32); + } + if (threadIdx.x == 0) { + *(float*)shared_data = 0; + } + __syncthreads(); + + if (threadIdx.x % 32 == 0) { + atomicAdd((float*)shared_data, sum_square); + } + + __syncthreads(); + sum_square = *(float*)shared_data; + return sum_square; + } +}; + + +template < + class InputDtype, + class OutputDtype, + class WeightDtype, + int MaxHiddenSize, + int NumThrPerCta +> +bool rmsnorm( + void const *Iptr, void const *Wptr, + void *Optr, float eps, + int64_t m, int64_t n, + hipStream_t stream = nullptr +) { + BOOL_SWITCH(n % MaxHiddenSize == 0, IsEven, [&] { + using Kernel = RMSNorm< + InputDtype, OutputDtype, WeightDtype, + MaxHiddenSize, NumThrPerCta, + IsEven>; + using Arguments = typename Kernel::Arguments; + Arguments args = { + Iptr, Wptr, Optr, eps, m, n + }; + auto params = Kernel::to_underlying_arguments(args); + auto grid_shape = Kernel::get_grid_size(m, n); + auto cta_shape = Kernel::get_cta_size(m, n); + static constexpr size_t ShmSize = Kernel::ShmSize; + launch_kernel(params, grid_shape, cta_shape, ShmSize, stream); + }); + return true; +} \ No newline at end of file diff --git a/turbodiffusion/ops/quant/quant.hip b/turbodiffusion/ops/quant/quant.hip new file mode 100644 index 0000000..878d9c3 --- /dev/null +++ b/turbodiffusion/ops/quant/quant.hip @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Citation (please cite if you use this code): + * + * @article{zhang2025turbodiffusion, + * title={TurboDiffusion: Accelerating Video Diffusion Models by 100-200 Times}, + * author={Zhang, Jintao and Zheng, Kaiwen and Jiang, Kai and Wang, Haoxu and Stoica, Ion and Gonzalez, Joseph E and Chen, Jianfei and Zhu, Jun}, + * journal={arXiv preprint arXiv:2512.16093}, + * year={2025} + * } + * + * HIP/ROCm quantization kernel. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/common_hip.hpp" +#include "quant/quant_hip.hpp" + +auto quant( + torch::Tensor const& Input, + std::optional& Output, + std::optional& Output_S +) { + using ElementOut = int8_t; + static constexpr int BlockSize = 128; + static constexpr int NumThrPerCta = 256; + + int64_t m = Input.size(0); + int64_t n = Input.size(1); + torch::Device const input_device = Input.device(); + + create_tensor(input_device, Output, Output_S, m, n); + + ElementOut *Optr = (ElementOut*)Output.value().data_ptr(); + float *OSptr = Output_S.value().data_ptr(); + + switch (Input.scalar_type()) { + case torch::kHalf: { + __half *Iptr = (__half*)Input.data_ptr(); + quantization<__half, BlockSize, NumThrPerCta>( + Iptr, Optr, OSptr, m, n, at::hip::getCurrentHIPStream().stream() + ); + break; + } + + case torch::kBFloat16: { + hip_bfloat16 *Iptr = (hip_bfloat16*)Input.data_ptr(); + quantization( + Iptr, Optr, OSptr, m, n, at::hip::getCurrentHIPStream().stream() + ); + break; + } + + default: { + std::cerr << "Observing: " << Input.scalar_type() << " for the input datatype which is invalid"; + throw std::runtime_error("Unsupported input data type for quantize_to_fp4."); + } + } + + return std::make_tuple(Output, Output_S); +} + +void register_quant(pybind11::module_ &m) { + m.def("quant_cuda", &quant); +} + diff --git a/turbodiffusion/ops/quant/quant_hip.hpp b/turbodiffusion/ops/quant/quant_hip.hpp new file mode 100644 index 0000000..d4f96f6 --- /dev/null +++ b/turbodiffusion/ops/quant/quant_hip.hpp @@ -0,0 +1,194 @@ +// !!! This is a file automatically generated by hipify!!! +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Citation (please cite if you use this code): + * + * @article{zhang2025turbodiffusion, + * title={TurboDiffusion: Accelerating Video Diffusion Models by 100-200 Times}, + * author={Zhang, Jintao and Zheng, Kaiwen and Jiang, Kai and Wang, Haoxu and Stoica, Ion and Gonzalez, Joseph E and Chen, Jianfei and Zhu, Jun}, + * journal={arXiv preprint arXiv:2512.16093}, + * year={2025} + * } + */ + +#pragma once + +#include +#include +#include "cutlass/numeric_conversion_hip.h" + +#include "common/load.hpp" +#include "common/store_hip.hpp" +#include "common/launch_hip.hpp" + +template < + class InputDtype_, + int NumThrPerCta_, + bool IsEvenM, + bool IsEvenN +> +class Quantization { +public: + using InputDtype = InputDtype_; + using OutputDtype = int8_t; + using FPConverter = cutlass::NumericConverter; + + static constexpr int BlockSize = 128; + static constexpr int NumThrPerCta = NumThrPerCta_; + static constexpr int NumElementPerThread = BlockSize * BlockSize / NumThrPerCta; + static constexpr int NumThrPerRow = BlockSize / NumElementPerThread; + + static_assert(BlockSize * BlockSize % NumThrPerCta == 0); + static_assert(NumThrPerCta % BlockSize == 0); + + static constexpr size_t ShmSize = 32; + + static constexpr float int8_max = 128.f; + + struct Params { + void const *Iptr; + void *Optr; + void *OSptr; + int64_t const m; + int64_t const n; + }; + + using Arguments = Params; + + static Params to_underlying_arguments(Arguments const& args) { + return args; + } + + static dim3 get_grid_size(int64_t m, int64_t n) { + return dim3( + cdiv(n, BlockSize), + cdiv(m, BlockSize) + ); + } + + static dim3 get_cta_size(int64_t m, int64_t n) { + return dim3( + NumThrPerCta, 1, 1 + ); + } + + CUTLASS_DEVICE + void quantization( + float 
*float_reg, + void *Optr, void *OSptr, + int64_t const m, int64_t const n, + int blk_m, int blk_n, int tidx, + char *shared_data + ) { + + OutputDtype output_reg[NumElementPerThread]; + + + Saver saver; + + float amax = _reduce_amax(float_reg, (float*)shared_data); + + + _quantization(float_reg, output_reg, int8_max / amax); + + float scale_inv = amax / int8_max; + + saver.store(Optr, OSptr, output_reg, scale_inv, m, n, blk_m, blk_n, tidx); + + __syncthreads(); + } + + + CUTLASS_DEVICE + void operator()(Params const& params, char *shared_data) { + int blk_m = blockIdx.y; + int blk_n = blockIdx.x; + int tidx = threadIdx.x; + + float float_reg[NumElementPerThread]; + + // load float32 data + Loader loader; + loader.load(params.Iptr, float_reg, params.m, params.n, blk_m, blk_n, tidx); + quantization( + float_reg, params.Optr, params.OSptr, params.m, params.n, blk_m, blk_n, tidx, shared_data + ); + } + +private: + + CUTLASS_DEVICE float + _reduce_amax(float *reg, float *smem_ptr) { + float amax = 1e-8f; + // thread reduction + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) + amax = fmaxf(amax, fabsf(reg[i])); + + // warp reduction - use __shfl_xor for HIP (no sync needed on AMD) + CUTLASS_PRAGMA_UNROLL + for (int i = 16; i >= 1; i /= 2) { + amax = fmaxf( + __shfl_xor(amax, i, 32), + amax + ); + } + + // cta reduction + if (threadIdx.x == 0) { + *smem_ptr = 0; + } + __syncthreads(); + + atomicMax((uint32_t*)smem_ptr, reinterpret_cast(amax)); + + __syncthreads(); + + amax = *smem_ptr; + + return amax; + } + + CUTLASS_DEVICE void + _quantization(float *float_reg, OutputDtype *out_reg, float scale) { + FPConverter converter; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + out_reg[i] = converter(float_reg[i] * scale); + } + + } +}; + +template < + class InputDtype, + int BlockSize, + int NumThrPerCta +> +bool quantization( + void const *Iptr, void *Optr, void *OSptr, + int64_t m, int64_t n, + hipStream_t stream = nullptr +) { + BOOL_SWITCH(m % BlockSize == 0, IsEvenM, [&] { + BOOL_SWITCH(n % BlockSize == 0, IsEvenN, [&] { + using Kernel = Quantization< + InputDtype, NumThrPerCta, IsEvenM, IsEvenN>; + using Arguments = typename Kernel::Arguments; + Arguments args = { + Iptr, Optr, OSptr, + m, n + }; + auto params = Kernel::to_underlying_arguments(args); + auto grid_shape = Kernel::get_grid_size(m, n); + auto cta_shape = Kernel::get_cta_size(m, n); + static constexpr size_t ShmSize = Kernel::ShmSize; + launch_kernel(params, grid_shape, cta_shape, ShmSize, stream); + }); + }); + + return true; +} diff --git a/turbodiffusion/ops/quant/quant_rocm.hpp b/turbodiffusion/ops/quant/quant_rocm.hpp new file mode 100644 index 0000000..4647b7a --- /dev/null +++ b/turbodiffusion/ops/quant/quant_rocm.hpp @@ -0,0 +1,261 @@ +/* + * Copyright (c) 2025 by TurboDiffusion team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * + * Quantization kernel for AMD GPUs using HIP. 
+ */ + +#pragma once + +#include +#include +#include +#include "common/platform.hpp" + +// Helper for input type to float conversion (handles half types properly) +template +TURBO_DEVICE TURBO_INLINE float input_to_float(T val); + +template <> +TURBO_DEVICE TURBO_INLINE float input_to_float<__half>(__half val) { + return __half2float(val); +} + +template <> +TURBO_DEVICE TURBO_INLINE float input_to_float(hip_bfloat16 val) { + return static_cast(val); +} + +template <> +TURBO_DEVICE TURBO_INLINE float input_to_float(float val) { + return val; +} + +TURBO_HOST_DEVICE int64_t cdiv(int64_t a, int64_t b) { + return (a + b - 1) / b; +} + +#define MIN(a, b) ((a) > (b) ? (b) : (a)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ +[&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return (__VA_ARGS__)(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return (__VA_ARGS__)(); \ + } \ +}() + +template < + class InputDtype_, + int NumThrPerCta_, + bool IsEvenM, + bool IsEvenN +> +class QuantizationHIP { +public: + using InputDtype = InputDtype_; + using OutputDtype = int8_t; + + static constexpr int BlockSize = 128; + static constexpr int NumThrPerCta = NumThrPerCta_; + static constexpr int NumElementPerThread = BlockSize * BlockSize / NumThrPerCta; + static constexpr int NumThrPerRow = BlockSize / NumElementPerThread; + + static_assert(BlockSize * BlockSize % NumThrPerCta == 0); + static_assert(NumThrPerCta % BlockSize == 0); + + static constexpr size_t ShmSize = 32; + static constexpr float int8_max = 128.f; + + struct Params { + void const* Iptr; + void* Optr; + void* OSptr; + int64_t const m; + int64_t const n; + }; + + using Arguments = Params; + + static Params to_underlying_arguments(Arguments const& args) { + return args; + } + + static dim3 get_grid_size(int64_t m, int64_t n) { + return dim3( + cdiv(n, BlockSize), + cdiv(m, BlockSize) + ); + } + + static dim3 get_cta_size(int64_t m, int64_t n) { + return dim3(NumThrPerCta, 1, 1); + } + + TURBO_DEVICE + void operator()(Params const& params, char* shared_data) { + int blk_m = blockIdx.y; + int blk_n = blockIdx.x; + int tidx = threadIdx.x; + + float float_reg[NumElementPerThread]; + + // Load input data + load_input(params.Iptr, float_reg, params.m, params.n, blk_m, blk_n, tidx); + + // Quantize + quantize(float_reg, params.Optr, params.OSptr, params.m, params.n, blk_m, blk_n, tidx, shared_data); + } + +private: + TURBO_DEVICE + void load_input(void const* input_ptr, float* thr_output_reg, + int64_t m, int64_t n, int blk_m, int blk_n, int tid) { + int thr_m_offset = tid / NumThrPerRow; + int thr_n_offset = (tid % NumThrPerRow) * NumElementPerThread; + + int64_t global_m = blk_m * BlockSize + thr_m_offset; + int64_t global_n = blk_n * BlockSize + thr_n_offset; + + InputDtype const* input = reinterpret_cast(input_ptr); + + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + if (IsEvenM && IsEvenN) { + thr_output_reg[i] = input_to_float(input[global_m * n + global_n + i]); + } else { + if (global_m < m && (global_n + i) < n) { + thr_output_reg[i] = input_to_float(input[global_m * n + global_n + i]); + } else { + thr_output_reg[i] = 0.0f; + } + } + } + } + + TURBO_DEVICE + void quantize(float* float_reg, void* Optr, void* OSptr, + int64_t m, int64_t n, int blk_m, int blk_n, int tidx, char* shared_data) { + OutputDtype output_reg[NumElementPerThread]; + + float amax = reduce_amax(float_reg, (float*)shared_data); + + float scale = int8_max / amax; + + 
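+        // scale each value, round to nearest, and clamp to the int8 range before narrowing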
TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + float val = float_reg[i] * scale; + val = fmaxf(-128.0f, fminf(127.0f, rintf(val))); + output_reg[i] = static_cast(val); + } + + float scale_inv = amax / int8_max; + + // Store output + store_output(Optr, OSptr, output_reg, scale_inv, m, n, blk_m, blk_n, tidx); + + __syncthreads(); + } + + TURBO_DEVICE + float reduce_amax(float* reg, float* smem_ptr) { + float amax = 1e-8f; + + // Thread reduction + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + amax = fmaxf(amax, fabsf(reg[i])); + } + + __syncwarp(); + + // Warp reduction + TURBO_PRAGMA_UNROLL + for (int i = 16; i >= 1; i /= 2) { + amax = fmaxf(__shfl_xor(amax, i, 32), amax); + } + + // CTA reduction + if (threadIdx.x == 0) { + *smem_ptr = 0; + } + __syncthreads(); + + atomicMax((unsigned int*)smem_ptr, __float_as_uint(amax)); + + __syncthreads(); + + amax = __uint_as_float(*(unsigned int*)smem_ptr); + return amax; + } + + TURBO_DEVICE + void store_output(void* Optr, void* OSptr, OutputDtype* reg, float scale_inv, + int64_t m, int64_t n, int blk_m, int blk_n, int tid) { + int thr_m_offset = tid / NumThrPerRow; + int thr_n_offset = (tid % NumThrPerRow) * NumElementPerThread; + + int64_t global_m = blk_m * BlockSize + thr_m_offset; + int64_t padded_n = cdiv(n, BlockSize) * BlockSize; + int64_t global_n = blk_n * BlockSize + thr_n_offset; + + OutputDtype* output = reinterpret_cast(Optr); + + TURBO_PRAGMA_UNROLL + for (int i = 0; i < NumElementPerThread; ++i) { + if (IsEvenM && IsEvenN) { + output[global_m * padded_n + global_n + i] = reg[i]; + } else { + if (global_m < m && (global_n + i) < n) { + output[global_m * padded_n + global_n + i] = reg[i]; + } + } + } + + if (tid == 0) { + float* scale_ptr = reinterpret_cast(OSptr); + scale_ptr[blk_m * cdiv(n, BlockSize) + blk_n] = scale_inv; + } + } +}; + +template +__global__ void quant_kernel_hip(typename Kernel::Params const params) { + extern __shared__ char smem[]; + Kernel op; + op(params, smem); +} + +template < + class InputDtype, + int BlockSize, + int NumThrPerCta +> +bool quantization_hip( + void const* Iptr, void* Optr, void* OSptr, + int64_t m, int64_t n, + hipStream_t stream = nullptr +) { + BOOL_SWITCH(m % BlockSize == 0, IsEvenM, [&] { + BOOL_SWITCH(n % BlockSize == 0, IsEvenN, [&] { + using Kernel = QuantizationHIP; + using Arguments = typename Kernel::Arguments; + + Arguments args = {Iptr, Optr, OSptr, m, n}; + auto params = Kernel::to_underlying_arguments(args); + auto grid_shape = Kernel::get_grid_size(m, n); + auto cta_shape = Kernel::get_cta_size(m, n); + static constexpr size_t ShmSize = Kernel::ShmSize; + + hipLaunchKernelGGL(quant_kernel_hip, grid_shape, cta_shape, ShmSize, stream, params); + }); + }); + + return true; +} + diff --git a/turbodiffusion/rcm/networks/wan2pt1.py b/turbodiffusion/rcm/networks/wan2pt1.py index 1f55c09..a0b28a3 100644 --- a/turbodiffusion/rcm/networks/wan2pt1.py +++ b/turbodiffusion/rcm/networks/wan2pt1.py @@ -29,14 +29,17 @@ flash_apply_rotary_emb = None print("flash_attn is not installed.") -from torch.distributed import ProcessGroup, get_process_group_ranks -from torch.distributed._composable.fsdp import fully_shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper +from torch.distributed import ProcessGroup + +if torch.distributed.is_available(): + from torch.distributed import ProcessGroup, get_process_group_ranks + from torch.distributed._composable.fsdp import fully_shard + from 
rcm.utils.context_parallel import split_inputs_cp, cat_outputs_cp, cat_outputs_cp_with_grad, broadcast from imaginaire.utils import log from rcm.utils.a2a_cp import MinimalA2AAttnOp from rcm.utils.selective_activation_checkpoint import CheckpointMode, SACConfig -from rcm.utils.context_parallel import split_inputs_cp, cat_outputs_cp, cat_outputs_cp_with_grad, broadcast T5_CONTEXT_TOKEN_NUMBER = 512 FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2 @@ -543,7 +546,7 @@ def __init__( Epsilon value for normalization layers """ - super().__init__() + super().__init__() assert model_type in ["t2v", "i2v", "flf2v"] self.model_type = model_type diff --git a/turbodiffusion/rcm/networks/wan2pt1_jvp.py b/turbodiffusion/rcm/networks/wan2pt1_jvp.py index cd06cf8..cb521c4 100644 --- a/turbodiffusion/rcm/networks/wan2pt1_jvp.py +++ b/turbodiffusion/rcm/networks/wan2pt1_jvp.py @@ -21,10 +21,12 @@ import torch.nn as nn from einops import rearrange, repeat from flash_attn.layers.rotary import apply_rotary_emb as flash_apply_rotary_emb -from torch.distributed import ProcessGroup, get_process_group_ranks -from torch.distributed._composable.fsdp import fully_shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper +if torch.distributed.is_available(): + from torch.distributed import ProcessGroup, get_process_group_ranks + from torch.distributed._composable.fsdp import fully_shard + from imaginaire.utils import log from rcm.utils.a2a_cp import MinimalA2AAttnOp from rcm.utils.selective_activation_checkpoint import CheckpointMode, SACConfig diff --git a/turbodiffusion/rcm/networks/wan2pt2.py b/turbodiffusion/rcm/networks/wan2pt2.py index fc41564..55b6620 100644 --- a/turbodiffusion/rcm/networks/wan2pt2.py +++ b/turbodiffusion/rcm/networks/wan2pt2.py @@ -29,14 +29,17 @@ flash_apply_rotary_emb = None print("flash_attn is not installed.") -from torch.distributed import ProcessGroup, get_process_group_ranks -from torch.distributed._composable.fsdp import fully_shard +from torch.distributed import ProcessGroup from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper +if torch.distributed.is_available(): + from torch.distributed import get_process_group_ranks + from torch.distributed._composable.fsdp import fully_shard + from rcm.utils.context_parallel import split_inputs_cp, cat_outputs_cp, cat_outputs_cp_with_grad, broadcast + from imaginaire.utils import log from rcm.utils.a2a_cp import MinimalA2AAttnOp from rcm.utils.selective_activation_checkpoint import CheckpointMode, SACConfig -from rcm.utils.context_parallel import split_inputs_cp, cat_outputs_cp, cat_outputs_cp_with_grad, broadcast T5_CONTEXT_TOKEN_NUMBER = 512 FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2 diff --git a/turbodiffusion/rcm/utils/context_parallel.py b/turbodiffusion/rcm/utils/context_parallel.py index af27ffa..06f1b0b 100644 --- a/turbodiffusion/rcm/utils/context_parallel.py +++ b/turbodiffusion/rcm/utils/context_parallel.py @@ -16,8 +16,11 @@ import torch from torch import Tensor -from torch.distributed import ProcessGroup, all_gather, broadcast_object_list, get_process_group_ranks, get_world_size -from torch.distributed.utils import _verify_param_shape_across_processes +if torch.distributed.is_available(): + from torch.distributed import ProcessGroup, all_gather, broadcast_object_list, get_process_group_ranks, get_world_size + from torch.distributed.utils import _verify_param_shape_across_processes +else: + from 
torch.distributed import ProcessGroup from imaginaire.utils import distributed From ffce8b99322672e69c17e7dea76d0a46745ed6fc Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Tue, 30 Dec 2025 18:33:46 +0530 Subject: [PATCH 2/4] Add AMD Windows setup guide --- README_AMD_WINDOWS.md | 195 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 195 insertions(+) create mode 100644 README_AMD_WINDOWS.md diff --git a/README_AMD_WINDOWS.md b/README_AMD_WINDOWS.md new file mode 100644 index 0000000..65c1b24 --- /dev/null +++ b/README_AMD_WINDOWS.md @@ -0,0 +1,195 @@ +# TurboDiffusion - AMD ROCm on Windows Setup Guide + +This guide explains how to build and run TurboDiffusion on Windows with AMD GPUs using ROCm. + +> **Note:** These steps should also work on Linux with minor modifications (use bash commands instead of PowerShell, `source venv/bin/activate` instead of `.\venv\Scripts\Activate.ps1`, and skip the Visual Studio environment setup). However, Linux support has not been tested yet and may have issues. + +## Supported Hardware + +TurboDiffusion on Windows has been tested with RDNA3/RDNA3.5 GPUs (gfx1100, gfx1101, gfx1102, gfx1103, gfx1151). + +## Prerequisites + +- Windows 10/11 +- Python 3.11, 3.12, or 3.13 +- Visual Studio 2022 with C++ build tools +- AMD Adrenaline driver (latest recommended) + +## Installation + +### 1. Install ROCm and PyTorch from TheRock + +Follow the instructions at [ROCm/TheRock RELEASES.md](https://github.com/ROCm/TheRock/blob/main/RELEASES.md) to install ROCm and PyTorch wheels for your GPU architecture. + +#### Create a Virtual Environment + +```powershell +python -m venv venv +.\venv\Scripts\Activate.ps1 +``` + +#### Install PyTorch (includes ROCm SDK as dependency) + +For **gfx1151** (AMD Strix Halo iGPU): +```powershell +pip install --index-url https://rocm.nightlies.amd.com/v2/gfx1151/ --pre torch torchaudio torchvision +``` + +For **gfx110X** (RX 7900 XTX, RX 7800 XT, RX 7700S, Radeon 780M): +```powershell +pip install --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/ --pre torch torchaudio torchvision +``` + +For **gfx120X** (RX 9060, RX 9070): +```powershell +pip install --index-url https://rocm.nightlies.amd.com/v2/gfx120X-all/ --pre torch torchaudio torchvision +``` + +#### Initialize ROCm SDK + +```powershell +rocm-sdk init +``` + +#### Install Triton with AMD Windows Support + +Install triton-windows and apply the AMD Windows support patches: + +```powershell +pip install triton-windows +``` + +**Temporary Fix:** Until [PR #179](https://github.com/woct0rdho/triton-windows/pull/179) is merged, you need to manually copy the patched files into your venv: + +1. Clone the PR branch: +```powershell +git clone --branch jam/windows_amd https://github.com/woct0rdho/triton-windows.git triton-windows-patch +``` + +2. Copy the patched files to your venv: +```powershell +$TRITON_AMD = ".\venv\Lib\site-packages\triton\backends\amd" +Copy-Item "triton-windows-patch\third_party\amd\backend\driver.py" "$TRITON_AMD\driver.py" -Force +Copy-Item "triton-windows-patch\third_party\amd\backend\driver.c" "$TRITON_AMD\driver.c" -Force +Copy-Item "triton-windows-patch\third_party\amd\backend\compiler.py" "$TRITON_AMD\compiler.py" -Force +Copy-Item "triton-windows-patch\python\triton\runtime\build.py" ".\venv\Lib\site-packages\triton\runtime\build.py" -Force +``` + +### 2. 
Set Environment Variables + +Open a PowerShell terminal and run: + +```powershell +# Activate Visual Studio environment +cmd /c '"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 && set' | ForEach-Object { if ($_ -match '^([^=]+)=(.*)$') { [System.Environment]::SetEnvironmentVariable($matches[1], $matches[2], 'Process') } } + +# Activate the virtual environment +.\venv\Scripts\Activate.ps1 + +# Set ROCm paths using rocm-sdk +$ROCM_ROOT = (rocm-sdk path --root).Trim() +$ROCM_BIN = (rocm-sdk path --bin).Trim() +$env:ROCM_HOME = $ROCM_ROOT +$env:PATH = "$ROCM_ROOT\lib\llvm\bin;$ROCM_BIN;$env:PATH" + +# Set compiler and build settings +$env:CC = "clang-cl" +$env:CXX = "clang-cl" +$env:DISTUTILS_USE_SDK = "1" + +# Enable experimental features +$env:FLASH_ATTENTION_TRITON_AMD_ENABLE = "TRUE" +$env:TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL = "1" + +# Set PYTHONPATH for TurboDiffusion +$env:PYTHONPATH = "turbodiffusion" +``` + +### 3. Install Dependencies + +```powershell +pip install -r requirements.txt +``` + +### 4. Build and Install TurboDiffusion + +```powershell +cd +pip install --no-build-isolation -e . +``` + +### 5. Install SpargeAttn (Optional, for sparse attention) + +If you want to use sparse attention with TurboDiffusion: + +```powershell +cd +pip install --no-build-isolation -v . +``` + +## Running Inference + +### Text-to-Video with Wan2.1 + +```powershell +# Make sure environment variables are set (see step 2) + +python turbodiffusion/inference/wan2.1_t2v_infer.py ` + --model Wan2.1-1.3B ` + --dit_path checkpoints/TurboWan2.1-T2V-1.3B-480P-quant.pth ` + --resolution 480p ` + --prompt "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage." ` + --num_samples 1 ` + --num_steps 4 ` + --quant_linear ` + --attention_type sagesla ` + --sla_topk 0.1 +``` + +### Available Attention Types + +- `sdpa` - PyTorch Scaled Dot Product Attention +- `sagesla` - SageAttention with Sparse Linear Attention (requires SpargeAttn) + +## Environment Variable Summary + +| Variable | Value | Description | +|----------|-------|-------------| +| `CC` | `clang-cl` | C compiler | +| `CXX` | `clang-cl` | C++ compiler | +| `DISTUTILS_USE_SDK` | `1` | Use SDK for distutils | +| `ROCM_HOME` | `` | ROCm SDK root path | +| `PATH` | Include LLVM and ROCm bin | Required for hipcc, clang, lld-link | +| `FLASH_ATTENTION_TRITON_AMD_ENABLE` | `TRUE` | Enable Triton Flash Attention on AMD | +| `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL` | `1` | Enable experimental aotriton kernels | +| `PYTHONPATH` | `turbodiffusion` | Include turbodiffusion module | + +## Known Issues + +1. **Triton compiler warnings** - You may see `clang-cl: warning: unknown argument ignored` warnings during first run. These are harmless. + +2. **First run is slow** - Triton and MIOpen kernels are compiled on first use and cached. Subsequent runs will be faster. + +3. **No FP8 support on RDNA3** - RDNA3 GPUs don't support FP8, so FP16/BF16 kernels are used. + +## Troubleshooting + +### "LoadLibrary failed" or "cannot find amdhip64.dll" + +Make sure you ran `rocm-sdk init` after installing the ROCm SDK packages. 
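
If `rocm-sdk init` has been run and the error persists, a quick way to confirm that the ROCm PyTorch wheels actually see the GPU is the short Python check below (an illustrative snippet, not part of the official tooling; ROCm devices are surfaced through the `torch.cuda` API):

```python
# Sanity check for the ROCm-enabled PyTorch install (run inside the activated venv).
import torch

print(torch.__version__)           # should match the ROCm nightly wheel you installed
print(torch.version.hip)           # HIP runtime version; None indicates a non-ROCm build
print(torch.cuda.is_available())   # True when the AMD GPU and driver are visible
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))   # e.g. the Radeon GPU's marketing name
```

If `torch.version.hip` is `None` or `is_available()` returns `False`, reinstall the wheels from the architecture-specific index listed in step 1 before rebuilding the extensions.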
+ +### "LINK : fatal error LNK1104: cannot open file 'python312.lib'" + +Ensure Visual Studio environment is activated before building: +```powershell +cmd /c '"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 && set' | ForEach-Object { if ($_ -match '^([^=]+)=(.*)$') { [System.Environment]::SetEnvironmentVariable($matches[1], $matches[2], 'Process') } } +``` + +### "PermissionError" when compiling Triton kernels + +This is a known Windows issue with temp file handling. Make sure you've applied the AMD Windows support patches from [PR #179](https://github.com/woct0rdho/triton-windows/pull/179) as described in the installation steps above. + +### "flash_attn is not installed" warning + +This warning is expected. Flash Attention is not available on AMD GPUs, but Triton-based attention is used instead when `FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE` is set. + From 4b7674c117be529f276b0968b8da9597b108f35f Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Mon, 5 Jan 2026 17:05:03 +0530 Subject: [PATCH 3/4] Add numeric_conversion_hip header --- .../ops/common/numeric_conversion_hip.hpp | 51 +++++++++++++++++++ turbodiffusion/ops/quant/quant_hip.hpp | 2 +- 2 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 turbodiffusion/ops/common/numeric_conversion_hip.hpp diff --git a/turbodiffusion/ops/common/numeric_conversion_hip.hpp b/turbodiffusion/ops/common/numeric_conversion_hip.hpp new file mode 100644 index 0000000..c8fccf9 --- /dev/null +++ b/turbodiffusion/ops/common/numeric_conversion_hip.hpp @@ -0,0 +1,51 @@ +// Compatibility header for CUTLASS numeric conversion on HIP/ROCm +// This provides a minimal subset of CUTLASS functionality needed for TurboDiffusion + +#pragma once + +#include +#include + +namespace cutlass { + +// FloatRoundStyle enum (subset of CUTLASS) +enum class FloatRoundStyle { + round_to_nearest = 0, + round_toward_zero = 1, + round_toward_infinity = 2, + round_toward_neg_infinity = 3, +}; + +// NumericConverter template - provides float to int8 conversion with rounding +template +struct NumericConverter { + __device__ __host__ __forceinline__ + To operator()(From const& val) const { + return static_cast(val); + } +}; + +// Specialization for float to int8_t with round_to_nearest +template <> +struct NumericConverter { + __device__ __host__ __forceinline__ + int8_t operator()(float val) const { + // Round to nearest and clamp to int8 range [-128, 127] + val = fmaxf(-128.0f, fminf(127.0f, rintf(val))); + return static_cast(val); + } +}; + +// Specialization for float to int8_t with round_toward_zero +template <> +struct NumericConverter { + __device__ __host__ __forceinline__ + int8_t operator()(float val) const { + // Truncate and clamp to int8 range [-128, 127] + val = fmaxf(-128.0f, fminf(127.0f, truncf(val))); + return static_cast(val); + } +}; + +} // namespace cutlass + diff --git a/turbodiffusion/ops/quant/quant_hip.hpp b/turbodiffusion/ops/quant/quant_hip.hpp index d4f96f6..60d9bdd 100644 --- a/turbodiffusion/ops/quant/quant_hip.hpp +++ b/turbodiffusion/ops/quant/quant_hip.hpp @@ -18,7 +18,7 @@ #include #include -#include "cutlass/numeric_conversion_hip.h" +#include "common/numeric_conversion_hip.hpp" #include "common/load.hpp" #include "common/store_hip.hpp" From 6199e67116b0d513d6876f44149d436bca50d596 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Mon, 5 Jan 2026 19:03:13 +0530 Subject: [PATCH 4/4] Update README_AMD_WINDOWS.md --- README_AMD_WINDOWS.md | 37 
+++++++------------------------------ 1 file changed, 7 insertions(+), 30 deletions(-) diff --git a/README_AMD_WINDOWS.md b/README_AMD_WINDOWS.md index 65c1b24..52f013c 100644 --- a/README_AMD_WINDOWS.md +++ b/README_AMD_WINDOWS.md @@ -6,7 +6,7 @@ This guide explains how to build and run TurboDiffusion on Windows with AMD GPUs ## Supported Hardware -TurboDiffusion on Windows has been tested with RDNA3/RDNA3.5 GPUs (gfx1100, gfx1101, gfx1102, gfx1103, gfx1151). +TurboDiffusion on Windows has been tested with RDNA3/RDNA3.5 GPUs (gfx1100, gfx1101, gfx1102, gfx1151). ## Prerequisites @@ -53,28 +53,10 @@ rocm-sdk init #### Install Triton with AMD Windows Support -Install triton-windows and apply the AMD Windows support patches: - ```powershell pip install triton-windows ``` -**Temporary Fix:** Until [PR #179](https://github.com/woct0rdho/triton-windows/pull/179) is merged, you need to manually copy the patched files into your venv: - -1. Clone the PR branch: -```powershell -git clone --branch jam/windows_amd https://github.com/woct0rdho/triton-windows.git triton-windows-patch -``` - -2. Copy the patched files to your venv: -```powershell -$TRITON_AMD = ".\venv\Lib\site-packages\triton\backends\amd" -Copy-Item "triton-windows-patch\third_party\amd\backend\driver.py" "$TRITON_AMD\driver.py" -Force -Copy-Item "triton-windows-patch\third_party\amd\backend\driver.c" "$TRITON_AMD\driver.c" -Force -Copy-Item "triton-windows-patch\third_party\amd\backend\compiler.py" "$TRITON_AMD\compiler.py" -Force -Copy-Item "triton-windows-patch\python\triton\runtime\build.py" ".\venv\Lib\site-packages\triton\runtime\build.py" -Force -``` - ### 2. Set Environment Variables Open a PowerShell terminal and run: @@ -105,25 +87,20 @@ $env:TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL = "1" $env:PYTHONPATH = "turbodiffusion" ``` -### 3. Install Dependencies - -```powershell -pip install -r requirements.txt -``` - -### 4. Build and Install TurboDiffusion +### 3. Build and Install TurboDiffusion ```powershell cd pip install --no-build-isolation -e . ``` -### 5. Install SpargeAttn (Optional, for sparse attention) +### 4. Install SpargeAttn (Optional, for sparse attention) -If you want to use sparse attention with TurboDiffusion: +If you want to use sparse attention with TurboDiffusion, clone the AMD Windows fork: ```powershell -cd +git clone --branch jam/amd_windows https://github.com/jammm/SpargeAttn.git +cd SpargeAttn pip install --no-build-isolation -v . ``` @@ -187,7 +164,7 @@ cmd /c '"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Bu ### "PermissionError" when compiling Triton kernels -This is a known Windows issue with temp file handling. Make sure you've applied the AMD Windows support patches from [PR #179](https://github.com/woct0rdho/triton-windows/pull/179) as described in the installation steps above. +This is a known Windows issue with temp file handling. Make sure you're using the latest `triton-windows` package (`pip install --upgrade triton-windows`). ### "flash_attn is not installed" warning
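
For reference, the per-tile INT8 quantization implemented by the HIP kernel in `quant_hip.hpp` (PATCH 1) reduces an absolute maximum per CTA tile, rounds to nearest (mirroring the `NumericConverter<int8_t, float, round_to_nearest>` specialization added in PATCH 3), and stores the reciprocal scale `amax / 127` per tile. The NumPy sketch below is a rough host-side reference, not the kernel itself; the default tile size of 128, `int8_max = 127`, and the `127 / amax` quantization factor are assumptions inferred from the launch code and the `scale_inv` computation shown in the patch.

```python
# Host-side reference sketch of the per-tile symmetric INT8 quantization scheme.
import numpy as np

def quantize_per_tile(x: np.ndarray, block: int = 128) -> tuple[np.ndarray, np.ndarray]:
    """Quantize a float matrix into int8 tiles plus one float32 scale per tile."""
    m, n = x.shape
    n_pad = ((n + block - 1) // block) * block            # the kernel pads N to a tile multiple
    q = np.zeros((m, n_pad), dtype=np.int8)
    scales = np.zeros(((m + block - 1) // block, n_pad // block), dtype=np.float32)

    for bi in range(scales.shape[0]):
        for bj in range(scales.shape[1]):
            tile = x[bi * block:(bi + 1) * block, bj * block:(bj + 1) * block]
            amax = max(float(np.abs(tile).max()), 1e-8)    # CTA-wide amax, floored at 1e-8
            q_tile = np.clip(np.rint(tile * (127.0 / amax)), -128, 127).astype(np.int8)
            q[bi * block:bi * block + tile.shape[0],
              bj * block:bj * block + tile.shape[1]] = q_tile
            scales[bi, bj] = amax / 127.0                  # stored as scale_inv for dequantization
    return q, scales
```

Dequantization is then approximately `q[i, j] * scales[i // block, j // block]` over the unpadded region, which is why the kernel writes one `scale_inv` per `(blk_m, blk_n)` block alongside the padded int8 output.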