From 4b75ddfe60b3e33012dcf581f78bfe94dd460ad9 Mon Sep 17 00:00:00 2001
From: Sue0515
Date: Mon, 24 Nov 2025 18:43:26 -0800
Subject: [PATCH 1/3] refactored code for LLaVa setting

---
 .gitignore                                      |  10 ++
 BLIP_MRI/environment_llava.yaml                 |   6 +-
 .../__pycache__/__init__.cpython-311.pyc        | Bin 167 -> 0 bytes
 .../__pycache__/__init__.cpython-39.pyc         | Bin 149 -> 0 bytes
 .../datamodule_rsfMRI.cpython-311.pyc           | Bin 31331 -> 0 bytes
 .../datamodule_rsfMRI.cpython-39.pyc            | Bin 13705 -> 0 bytes
 .../__pycache__/dataset_T1.cpython-311.pyc      | Bin 21453 -> 0 bytes
 .../__pycache__/dataset_T1.cpython-39.pyc       | Bin 10546 -> 0 bytes
 .../dataset_rsfMRI.cpython-311.pyc              | Bin 23851 -> 0 bytes
 .../__pycache__/dataset_rsfMRI.cpython-39.pyc   | Bin 12761 -> 0 bytes
 BLIP_MRI/project/dataset/dataset_T1_LLaVa.py    |  18 +-
 .../__pycache__/Bblip_t5.cpython-311.pyc        | Bin 7113 -> 7139 bytes
 .../__pycache__/__init__.cpython-311.pyc        | Bin 165 -> 191 bytes
 BLIP_MRI/project/utils/Trainer.py               | 164 +++++++++++++-----
 .../utils/__pycache__/Trainer.cpython-311.pyc   | Bin 20035 -> 0 bytes
 .../utils/__pycache__/Trainer.cpython-39.pyc    | Bin 11530 -> 0 bytes
 .../__pycache__/__init__.cpython-311.pyc        | Bin 165 -> 0 bytes
 .../utils/__pycache__/__init__.cpython-39.pyc   | Bin 147 -> 0 bytes
 .../utils/__pycache__/data.cpython-311.pyc      | Bin 12276 -> 0 bytes
 .../utils/__pycache__/data.cpython-39.pyc       | Bin 7990 -> 0 bytes
 .../utils/__pycache__/utils.cpython-311.pyc     | Bin 2575 -> 0 bytes
 .../utils/__pycache__/utils.cpython-39.pyc      | Bin 1778 -> 0 bytes
 .../BLIP_MRI_Blip_DDP_interactive.sh            |   0
 .../BLIP_MRI_Blip_T1_DDP_interactive.sh         |   0
 .../BLIP_MRI_LLaVa_T1_DDP_interactive.sh        |  10 +-
 25 files changed, 152 insertions(+), 56 deletions(-)
 create mode 100644 .gitignore
 delete mode 100644 BLIP_MRI/project/dataset/__pycache__/__init__.cpython-311.pyc
 delete mode 100644 BLIP_MRI/project/dataset/__pycache__/__init__.cpython-39.pyc
 delete mode 100644 BLIP_MRI/project/dataset/__pycache__/datamodule_rsfMRI.cpython-311.pyc
 delete mode 100644 BLIP_MRI/project/dataset/__pycache__/datamodule_rsfMRI.cpython-39.pyc
 delete mode 100644 BLIP_MRI/project/dataset/__pycache__/dataset_T1.cpython-311.pyc
 delete mode 100644 BLIP_MRI/project/dataset/__pycache__/dataset_T1.cpython-39.pyc
 delete mode 100644 BLIP_MRI/project/dataset/__pycache__/dataset_rsfMRI.cpython-311.pyc
 delete mode 100644 BLIP_MRI/project/dataset/__pycache__/dataset_rsfMRI.cpython-39.pyc
 delete mode 100644 BLIP_MRI/project/utils/__pycache__/Trainer.cpython-311.pyc
 delete mode 100644 BLIP_MRI/project/utils/__pycache__/Trainer.cpython-39.pyc
 delete mode 100644 BLIP_MRI/project/utils/__pycache__/__init__.cpython-311.pyc
 delete mode 100644 BLIP_MRI/project/utils/__pycache__/__init__.cpython-39.pyc
 delete mode 100644 BLIP_MRI/project/utils/__pycache__/data.cpython-311.pyc
 delete mode 100644 BLIP_MRI/project/utils/__pycache__/data.cpython-39.pyc
 delete mode 100644 BLIP_MRI/project/utils/__pycache__/utils.cpython-311.pyc
 delete mode 100644 BLIP_MRI/project/utils/__pycache__/utils.cpython-39.pyc
 mode change 100644 => 100755 BLIP_MRI/sample_scripts/BLIP_MRI_Blip_DDP_interactive.sh
 mode change 100644 => 100755 BLIP_MRI/sample_scripts/BLIP_MRI_Blip_T1_DDP_interactive.sh
 mode change 100644 => 100755 BLIP_MRI/sample_scripts/BLIP_MRI_LLaVa_T1_DDP_interactive.sh

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..fd400d2
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,10 @@
+BLIP_MRI/project/hf_results
+BLIP_MRI/project/hf_logs
+BLIP_MRI/logs
+BLIP_MRI/project/wandb
+BLIP_MRI/project/dataset/__pycache__
+BLIP_MRI/project/model/__pycache__
+BLIP_MRI/project/utils/__pycache__
+BLIP_MRI/project/*.json
+__pycache__/
+*.pyc
diff --git a/BLIP_MRI/environment_llava.yaml b/BLIP_MRI/environment_llava.yaml
index 688c737..2496f24 100644
--- a/BLIP_MRI/environment_llava.yaml
+++ b/BLIP_MRI/environment_llava.yaml
@@ -1,4 +1,4 @@
-name: /pscratch/sd/h/heehaw/anaconda/BLIP_MRI_llava
+name: BLIP_MRI_llava
 channels:
   - conda-forge
 dependencies:
@@ -125,7 +125,7 @@ dependencies:
   - sympy==1.14.0
   - threadpoolctl==3.6.0
   - timm==0.4.12
-  - tokenizers==0.13.3
+  - tokenizers>=0.20
   - torch==2.8.0
   - torchvision==0.23.0
   - tqdm==4.67.1
@@ -138,4 +138,4 @@ dependencies:
   - wandb==0.17.0
   - xxhash==3.5.0
   - yarl==1.20.1
-prefix: /pscratch/sd/h/heehaw/anaconda/BLIP_MRI_llava
+prefix: /YOUR_DIRECTORY
\ No newline at end of file
diff --git a/BLIP_MRI/project/dataset/__pycache__/__init__.cpython-311.pyc b/BLIP_MRI/project/dataset/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 7a2f5298a34813273a694657519a4887e8bef44b..0000000000000000000000000000000000000000
Binary files a/BLIP_MRI/project/dataset/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/BLIP_MRI/project/dataset/__pycache__/__init__.cpython-39.pyc b/BLIP_MRI/project/dataset/__pycache__/__init__.cpython-39.pyc
deleted file mode 100644
index 9b3cbebd48a3732f3f2e633503ce858fe4ecb523..0000000000000000000000000000000000000000
Binary files a/BLIP_MRI/project/dataset/__pycache__/__init__.cpython-39.pyc and /dev/null differ
diff --git a/BLIP_MRI/project/dataset/__pycache__/datamodule_rsfMRI.cpython-311.pyc b/BLIP_MRI/project/dataset/__pycache__/datamodule_rsfMRI.cpython-311.pyc
deleted file mode 100644
index 0df9552b9815427a7ff59b5f03297107d27d7950..0000000000000000000000000000000000000000
Binary files a/BLIP_MRI/project/dataset/__pycache__/datamodule_rsfMRI.cpython-311.pyc and /dev/null differ
diff --git a/BLIP_MRI/project/dataset/__pycache__/dataset_T1.cpython-311.pyc b/BLIP_MRI/project/dataset/__pycache__/dataset_T1.cpython-311.pyc
deleted file mode 100644
index 1e30807933618c3fb06297c8ebb82e90ccc9a28a..0000000000000000000000000000000000000000
Binary files a/BLIP_MRI/project/dataset/__pycache__/dataset_T1.cpython-311.pyc and /dev/null differ
diff --git a/BLIP_MRI/project/dataset/__pycache__/dataset_T1.cpython-39.pyc b/BLIP_MRI/project/dataset/__pycache__/dataset_T1.cpython-39.pyc
deleted file mode 100644
index 5dde0427490f28e2c88b7a9f8b82cd737ec17000..0000000000000000000000000000000000000000
Binary files a/BLIP_MRI/project/dataset/__pycache__/dataset_T1.cpython-39.pyc and /dev/null differ
diff --git a/BLIP_MRI/project/dataset/__pycache__/dataset_rsfMRI.cpython-311.pyc b/BLIP_MRI/project/dataset/__pycache__/dataset_rsfMRI.cpython-311.pyc
deleted file mode 100644
index 342a18fea5f9c14a7834b026d19c4f272bcc5c0a..0000000000000000000000000000000000000000
Binary files a/BLIP_MRI/project/dataset/__pycache__/dataset_rsfMRI.cpython-311.pyc and /dev/null differ
z$N-1T*ip?IPZ(JXC$Q6KG!yU24}nD-Ym@zsvFM15zjA*Zc@c73lJ6T&^}=Q=I7Q2) zxce31DRS2M1%7}pygYv26)v}$mlSa(vhfK4K8|J(CQ5ahm-tB-PNXkYQE2%!)lu|p z@fjITtHW-$6)5_Hif0k{G;>v&oSNbjii1=#BP6COLqUv@(=CZViZ3lpRUkEa(+rG an4@2ukywSCa#3PQa)y3!ihhQEMrvwCV)^E4oV-E+ Dk=hV> diff --git a/BLIP_MRI/project/model/__pycache__/__init__.cpython-311.pyc b/BLIP_MRI/project/model/__pycache__/__init__.cpython-311.pyc index 987014c6b0dc50caa55c69d0382fba48f810d643..ca6bd956d8fd965b965f65ff4c4cf1624ee33b48 100644 GIT binary patch delta 72 zcmZ3=xSx?{IWI340}w2Fu97*C$HX8=zo0m|D6u3tL%%phzgWMxG&3(bBVRu`KP45+ X(J#(OEK1dPDoV`E3-j@v80HTEHQE?B delta 46 zcmdnbxRjA+IWI340}zy5U!Oja$3)CRzo0m|D6u3tL%%phKSMtwH8mr#d}4t=06NtU A)c^nh diff --git a/BLIP_MRI/project/utils/Trainer.py b/BLIP_MRI/project/utils/Trainer.py index c525c52..b102404 100644 --- a/BLIP_MRI/project/utils/Trainer.py +++ b/BLIP_MRI/project/utils/Trainer.py @@ -40,8 +40,7 @@ def preprocess_logits_for_metrics(logits, labels): pred_ids = torch.argmax(logits, dim=-1) return pred_ids - -@torch.no_grad() +# @torch.no_grad() def compute_metrics_with_tokenizer(tokenizer): def compute_metrics(eval_preds): predictions, labels = eval_preds @@ -53,10 +52,9 @@ def compute_metrics(eval_preds): pred_genders = [] true_genders = [] + import re for pred in decoded_preds: pred_clean = pred.lower().strip() - - import re if re.search(r'\bfemale\b', pred_clean): pred_genders.append(1) elif re.search(r'\bmale\b', pred_clean): @@ -66,8 +64,6 @@ def compute_metrics(eval_preds): for label in decoded_labels: label_clean = label.lower().strip() - - import re if re.search(r'\bfemale\b', label_clean): true_genders.append(1) elif re.search(r'\bmale\b', label_clean): @@ -75,20 +71,52 @@ def compute_metrics(eval_preds): else: true_genders.append(-1) + # Valid pairs valid_pairs = [(p, t) for p, t in zip(pred_genders, true_genders) if p != -1 and t != -1] if valid_pairs: valid_preds, valid_trues = zip(*valid_pairs) - accuracy = balanced_accuracy_score(valid_trues, valid_preds) - f1 = f1_score(valid_trues, valid_preds, average='macro') + valid_accuracy = balanced_accuracy_score(valid_trues, valid_preds) + valid_f1 = f1_score(valid_trues, valid_preds, average='macro') + else: + valid_accuracy = 0.0 + valid_f1 = 0.0 + + # Overall 메트릭 (invalid를 오답 처리) + overall_preds = [] + overall_trues = [] + + for p, t in zip(pred_genders, true_genders): + if t != -1: # ground truth가 유효한 경우만 + overall_trues.append(t) + if p == -1: + overall_preds.append(1 - t) + # overall_preds.append(-1) + else: + overall_preds.append(p) + + if overall_preds: + overall_accuracy = balanced_accuracy_score(overall_trues, overall_preds) + overall_f1 = f1_score(overall_trues, overall_preds, average='macro') else: - accuracy = 0.0 - f1 = 0.0 + overall_accuracy = 0.0 + overall_f1 = 0.0 + + total_samples = len(pred_genders) + invalid_predictions = pred_genders.count(-1) + response_rate = (total_samples - invalid_predictions) / total_samples if total_samples > 0 else 0 metrics = { - 'accuracy': accuracy, - 'f1': f1 - } + 'accuracy': valid_accuracy, + 'f1': valid_f1, + 'overall_accuracy': overall_accuracy, + 'overall_f1': overall_f1, + 'response_rate': response_rate, + 'valid_samples': len(valid_pairs), + 'total_samples': total_samples, + 'invalid_predictions': invalid_predictions + } + return metrics return compute_metrics @@ -297,9 +325,9 @@ def compute_loss(self, model, inputs, return_outputs=False): def training_step(self, model, inputs): loss = super().training_step(model, inputs) - # generation result - if self.state.global_step % 50 == 0 and 
self.state.global_step > 0: - self.log_generated_result(model, inputs) + # # generation result + # if self.state.global_step % 50 == 0 and self.state.global_step > 0: + # self.log_generated_result(model, inputs, mode="training") # Log gradients at logging steps modalities = list(inputs.keys()) @@ -476,35 +504,60 @@ def prediction_step( print(f" - logits shape: {logits.shape if logits is not None else None}") print(f" - labels shape: {labels.shape if labels is not None else None}") + # Log generated result during evaluation (first sample of each eval) + if not prediction_loss_only and not hasattr(self, '_eval_generation_logged'): + self._eval_generation_logged = True + self.log_generated_result(model, inputs, mode="evaluation") + return (loss, logits, labels) - def log_generated_result(self, model, inputs): + def log_generated_result(self, model, inputs, mode="training"): + """ + Log generated result during training or evaluation + + Args: + model: The model to use for generation + inputs: Input dictionary (wrapped or unwrapped) + mode: "training" or "evaluation" + """ actual_model = model.module if hasattr(model, 'module') else model - - actual_model.eval() + + # Only set eval mode for training (already in eval during evaluation) + if mode == "training": + actual_model.eval() + with torch.no_grad(): try: - modality = list(inputs.keys())[0] - sample_input = inputs[modality] - + # Handle input format (different for training vs evaluation) + if 'pixel_values' in inputs and 'input_ids' in inputs: + sample_input = inputs + else: + # Still wrapped in modality key (typical for training) + modality_keys = [k for k in inputs.keys() if k in ['T1', 'rsfMRI']] + if modality_keys: + sample_input = inputs[modality_keys[0]] + else: + sample_input = inputs + input_ids = sample_input['input_ids'][0] - + # Search ASSISTANT: token assistant_tokens = self.tokenizer.encode("ASSISTANT:", add_special_tokens=False) assistant_pos = None - + for i in range(len(input_ids) - len(assistant_tokens)): - if torch.equal(input_ids[i:i+len(assistant_tokens)], + if torch.equal(input_ids[i:i+len(assistant_tokens)], torch.tensor(assistant_tokens, device=input_ids.device)): assistant_pos = i + len(assistant_tokens) break - + if assistant_pos is None: - print("Warning: ASSISTANT: not found in input") + print(f"[WARN] ASSISTANT: not found in {mode} input") return - + prompt_ids = input_ids[:assistant_pos].unsqueeze(0) - + + # Generate generated_ids = actual_model.generate( pixel_values=sample_input['pixel_values'][0:1], input_ids=prompt_ids, @@ -513,37 +566,70 @@ def log_generated_result(self, model, inputs): temperature=0.1, pad_token_id=self.tokenizer.pad_token_id, ) - + generated_only = generated_ids[0][len(prompt_ids[0]):] generated_text = self.tokenizer.decode(generated_only, skip_special_tokens=True) - + + # Build result dictionary result = { + "type": mode, "step": self.state.global_step, "epoch": float(self.state.epoch) if hasattr(self.state, 'epoch') else 0, "generated_text": generated_text, } - + + # Add ground truth for evaluation mode + if mode == "evaluation": + labels = sample_input.get('labels', None) + if labels is not None: + labels_clean = labels[0].clone() + labels_clean[labels_clean == -100] = self.tokenizer.pad_token_id + ground_truth = self.tokenizer.decode(labels_clean, skip_special_tokens=True) + else: + ground_truth = "N/A" + result["ground_truth"] = ground_truth + + # Save to JSON json_file = "generation_logs.json" if os.path.exists(json_file): with open(json_file, 'r') as f: logs = json.load(f) else: 
logs = [] - + logs.append(result) - + with open(json_file, 'w') as f: json.dump(logs, f, indent=2, ensure_ascii=False) - print(f"Step: {self.state.global_step}") - print(f"Generated: {generated_text}") + # Print output + prefix = "[TRAIN]" if mode == "training" else "[EVAL]" + if mode == "evaluation": + print("\n" + "="*80) + print(f"{prefix} Step: {self.state.global_step}, Epoch: {result['epoch']}") + print(f"{prefix} Generated: {generated_text}") + print(f"{prefix} Ground Truth: {result.get('ground_truth', 'N/A')}") + print("="*80 + "\n") + else: + print(f"{prefix} Step: {self.state.global_step}") + print(f"{prefix} Generated: {generated_text}") except Exception as e: - print(f"[ERROR] Generation failed: {e}") + print(f"[ERROR] {mode.capitalize()} generation failed: {e}") import traceback traceback.print_exc() - - actual_model.train() + + # Restore train mode only if we changed it + if mode == "training": + actual_model.train() + + def evaluation_loop(self, *args, **kwargs): + """Override to reset generation flag at start of each evaluation""" + # Reset flag so we log generation once per eval + if hasattr(self, '_eval_generation_logged'): + delattr(self, '_eval_generation_logged') + + return super().evaluation_loop(*args, **kwargs) \ No newline at end of file diff --git a/BLIP_MRI/project/utils/__pycache__/Trainer.cpython-311.pyc b/BLIP_MRI/project/utils/__pycache__/Trainer.cpython-311.pyc deleted file mode 100644 index b20beb1dc871baa4957a68bdde66577d4792da61..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 20035 zcmeHvYit`wnqW6y;#1U%lJ#!ADAAE6S@KJ^WyiKGKV`?3{0JQlO|x4PWj>VdmTf6y z#K9ecxB#SH+(=999CJ-GB1WbN}NILHslNC_Kh| z;PYK1LA*!M1WkqrFM8`jM40rFBns<8y0G4>4;#FOu+eJ_o4h6*ZwOU{&0aGOlOc=O z0(p%gYuM(s;W$&s9(H&g;Yx2M{;mi)!&TlY95#ol!!_QTaILo%e_KLzVVBnxuJ_i5 z8@vtSZQgC+MsH)d$=ig}t)b>{i?=1*>TSi}woqHR-P?}C_7D~B@OFedy`AAMZ&$e6 z+a2!l_JrMDcevNv8}9S=K|aSl!t2oy46(4iF#Hc7@KND|peyeY+8FOEI^qODJMR&6 z6=VBA54quAdA$8JIYoG?e}z8Jc}Q6oiFgdMb==SU&qn<;!^#yX%rU&|xag12(eRW% zycmM`>M7>UB__fL{h@pm6t5W%ay%QnvBWcUDYjyo^#>yi>(R;9c|Ye1F_AfbUbb-z zL;G$=*_)6=%J$d%tUnZDLX%ONfvm0|=cARjIiLTQKN#}g2r+U4M05T*ChUh|zVM=t ziJ*AeW0Kodt)InhoM59%i^ZLZ6>Z2Ln zADEZxPVo$jI+F+4)V!aiVI&PV{2_lN02%#(z!K{ZEc>`Xlx1Y|>`o;F!?&QZ210(0 z^O$AF^y{x3_Z>TV{MeaO=T6A_Bavmw2L0}jU26)+U7GN!#&oA=9Xv80qjnhjo zF|zSe1ik~h{0yU)0shgHe|{9a_Xr-AN#R*2hGVd78Vk{|bUxO90Cjv+tS7FE>0;!p z4$#H0@`qLmp7U&QE(lFi`T3>7{$*<3e~Y2~RG0}z*<~uk_-_tSIQR5?kfXr)DKzDN zn3)Jo@lncuD;lIJj*qgz$Q(pOp_<4t#WN8u%5oGx@8>B$%TSRhPelM@7VL!YAt~$WU>G26gj!%COn!DQ?7!pD%cf9tF35B6_TOMaoNQiXVH$!o$02~g7{1&) zxX1-qKOdMMHQRR9NOHHJAK0K8f7Mug3b3k$q+&!7I z)g~_swr0W5{5fPL0>vqT4>U$Rbu|$9E=*p6m?dfQo;gm&(8|@F(_}}auonJ5ABIYP zK)y%Bhz0mmo;b;)s`3v_egJtsDmE5_3TfR6A7*`EES9Z76RtZE=2-pu^}wx=NX7!T;i$dEdH2Z znuEnKH|q~DRBs+ldI53gmH;ED-ePUNRNpzkWK*F+&kCr)}{7wkGN6f(#&e##2(1bFrtd1rlFv;K7J#b7B zAL?a8*bk%}P4ll19(+iAsFQUAL$aP@?rXDp>nPpyt{-@^{V1BTr!Zl{7yI-PF5$YGu6ce z_7%wZyNs9}f(qCX402_s=t^gX!0*?cVW-hP{SvR{Ac$@_jGDN{-eK5kyHbl^X9RON#vi%04W5A9A zz(QFU8fZOu9*b;1+1N1@(+6Im8vYJ&5KIzL$cXEf_0S}XNPui9Vl9W(8Cn(Sxr>oP zc9u=>%HFbw7sZTzsGIv4cwgKlHm!uK@!fL?%X;tj%?PA?|sqTEjlB;c5tMz1SJ(<3z2BFp?)=o;b zlL>PUW7ujC5K63iRjPV5Va(xaS-lEVAUY38&O-^qdSy?V$xMrt!&2pNLci{CC7IN8 zX0PBF5gj9vVgj6CJ@of1 zUIT3aNzzCg!8g%0Kwe321zk&9;BCgatoO`bi-mwR8)IE4l2O{u*ck^?$v79P3av8L zv$_(4P={#;Q$trWwM-r3DCH=HXeZ->R8?K2w3m&aRWYeqD_wQZ=&er>cSvso5cWFR zacl`>jBuVZulxxnwIFH+XM+q)-2f4nrhv51^ZX*We{c|_hWVu%h=vWSDElDK`Xd~$ z`7pzBgEvCa8-qi;hnYRQhv_}^?vcRojp04Bvoy1BFEcbUdSmbK$mp(7|L&1pqa%YH 
z8_*Vqx{y>1vP9R};n7{VL$Z}Ciq#q!4smugxu0JK4{>&#Mvl{p{3@wV7X*I#i2QdX zY#zuM;Y0X{AK>?dugJ$~0t*upAp9FS5Sdxo?DGX9LEh(+t9=~r=Ys*?9P3}4_o2N& zLl_85$QK3KGaOuj#Aw9FFE274qio=q&@B5Jv>_XS`+?|n^EQ54d_FAR`g|N3Oo@=N z=iqatUNb)fh1T&R+6@_CJ#3h_{ju?W#e)jTye-u&m_2DXc#6w)xDy7++?I2Xp*N%= zr!0`-n4ucPB(Xr1fq+&%;!c9hQqP`B29>2E4mO z&RHutdo~HZWqYPu+P)u$)}4*1p7&jkwn-hsQpai0c}8-c0n+N&o~v@LRrO}8dc`V_ zROJz>Jh>iEVqCI!3fOH{Kt;HYlBZ015w)RNG)Z{Uw*e|>H$X9+Rt%=en|N|L-D%(`^Mx{A0KC4M`k?)@DQy7z^US%tc#9<#$hon*Tjuur1W zCnbC9n!P(~?@rIG?vr{BiuN(dJ|3Dc5J(ro6+l2k& z961-jQU!n#JxCnH^2j2VM|4PaRsnK`4hiZYL8z2=C4m$q6D(}94YB9sDDiKNps)YV zA482rewH6KR0XtxZlpG5e35){Q;bNf z)O)20Vf^}-iLWcgq>KC0QXEmHM7%|SN*kb1vVO!XB1?P;XrN%^n@Z)iQVM&3HpC3K zNOmvZTFOHkX;Zqe=PpG5JNMz^cgR{I( zp{TL}m00PSL_0!?230;%)Uyn~#71ys1JntKxXCPY`>8L$o8ekZP;}BpbJb$uJpm)XOV%Xm2WHBwJTxv8mi9Xn}c0V z#2;qlO58gi7mh|j)ym2pljp~epY=_gzvw%5F0dYPXhLZ=LHMROew#pEqh6l0C`60V&2JJ!(`j9*U1}VW8b{X}k7XN=iH*ml#^VCGNc?{hM3g+K z`&Z@l8G4=8{aSe?PJ`NRt=65bb*FEynx1(7+Z4Yx^;UN3EpaL?O~r-daiP{N*1j#( zzMU9PjDHHd!QS_GiOxRB*@t$444Jw7c(1VID7f`{O6UkYt^Kdtez{G!`1;z#Z)Goj zOT2hpx_DhY<(E$RAr+ihPfPW*P(`o1hgKW@zFTx3mfVMh+Vjc1sXgx=P9A=8<>`f= zdlTau&f4UaL|ibpU^+daq0_~E<39r3K+1wN!1~#zoOYtEeB40mfQA;SF4`Za)ryE> zHD5&$;wHXGTcHM%mR`_$MX#SESt@2yuYOqVx>!Yt8f;DTwJpcZUQNG2}s^v>^f1R{0l|Ozfh#vWfb&Oj9jo5%9b>W3&r}@IL2*p zd(84@M9g-JV4Gsrm_1#oE>#F?OVi{E%;^fW6SqF-i5!C-6*QI{-VY3!_%QK_6j`IkcP)L$DHb6V_Pe3S^EOz7o&Od91hs zZPcjiE`$NwOxoE7b8-2Yba+gH@#CfDJlUh_xuX6HBT-32E0HT zL??jWf*|Y1-+}1jG8U;y3TpN$N`4(Bn_xLD!KoGdEzo3x5g{+wG2t5wO@sK!$yG?4 zqd_)g;SJ|&!0yQ{h1n%kTE7@wWZy=W1;O(;%cL^7QV~Oah3&MJkMdwUM+=snMa3B8 zp+U-y{LWdCd1VJSwnJ~wmV8;Z(o5m6B9zGnEMFRNhZMOFOLd5T7vPtFE5e<sngzv8n|`{lgCqr~Fx4yJ&0AHFTsq-i?>`*lR62vn@NtmJz9C zBr%!m8(JOD_U#k;_N7jyPClj+lba@@sU>mxIVyRxV|T8r@6kc23*3pduB+Lut74Z| z>hf+9y6WB7`~I2ac&aDY)GIW2GSfoCPH?$xooVJL!M_Op_~xUVtDQfNh}(9htY2&z zQ6VeQ?UA~6OI@$7bzR7IT@bsbq^_you0oyBQMFk`>>ibNoe*5nrzeH0jBpmYcSP5axGSJ*eU z1A9ZpLG=+YtpxNO+W-l65Wu{T&{D=vT(8!sgY2r3WD9x9ci38~1Z0Y$B#8FI+EmbH z;I`Nl$O#~mDNjkK07sFqo`U@@q!(r2l9ai1sTC6eN7wQ-FM##N2*x*6FAD9H_Z+0k zDJ&D2&_1~&rLrgBt0*CU=T+EKtF1RCkZAVMpo@Sv`JhvsH(iu!OKq07`Xcf!*nJE! 
z6VePI8_9#|2-4+z6};lVHN39B6ueGt39m6V48Bu>%{F!X|0UR*2I+LmJwVK;#XX>{ zU@-%E(5Pl>h8saYw_*e%H$4cfWD>wFJFuA*_z;*%;jrHyLR^WrDJl#Hoc%)wzid5zzHw_+#FucHfl*2jf4+Yu^FI+VAXxFEIHj7PzrAsoZB<}$1Bp{@u!AQ zEYeAzIOr1sbas#y2I*XDd(sLlWUXs7+che79gw;XfEu{^GO74^=F!k1HGo~$=(S;%MI&~;@2z71r z(9fFwq5W^$#nIE!=;^i5^V!k!;^+lw^g`-%`t1y#emniP*s=%p%;Na>=I@!4O`tWG z%q^gi*fbd}_RUJ@a>9m*eo5gZ*OktFVR^23woUDR5GPNdjJ;45F>W1ydK@)7AXbU@vSv>G0T!4mh&G_T*&-8bAF z;Z+(5XsO7DUWmb4+Xd5+NT!oO#T=;ny;!48!JZ4G;hcNM!s@C~If zhOp&jAY2A-vj$>jZ7OZe^GkEgtkD>oBRbFknbjKTg(5t(QpIk9B-DCwZ7S%eYWVij zcw!_#TLZtiB@crS_zqP*3*TA#(w1QN#4NO}q~QWBl^S9Oz)3qm%(At|jloYBv&3y*+h0151Ahl=R5h2SQEH)ta1L^AE2Wo02*cpCN2PYX zVTr{auLP?N{AuxtBqFvhVkf~FZtL#oUng#pB=I`50BzXQB{VBFQD$9wk$EkfjmodU z?q_Gi3A0i84Pdh)vKp&=274Wm_Yt6gPVJ!_2F5k}5}OseJqpy32g_fe7C z#T^kAa>VR0N6d_B_tf}Jz(|Ywo!1x^9cBfsB*h@>7e2`ecX60SbQbnH z*2Y|?mXK0}x~^zQuB$YZTWdAJm>7(pJ2PLQ+k`6*oc8b(Zj0)>9*GQ0Vx7=+k2V7m zaHm1Vs$_&l@;zph>p%thH?KAAyN)hNT?1WIKZS2Z*W?-}1nXs~}Kl6?H6s4Brpeo^A(0Q+$J=mLlj>20CBN0Gt_kAxO?b1oR*2 ziXe+>7nP;GfECo7k zWC4Z&=8eE;oLEPWL{x1PDsdHRMi<)9jLlr{r?#WB#pqt!_ASuh2!*taXikuztK83F zC`HSMVoQUk`hv4mkXP=Vl_tmoz=UgM+FDmx%ip=)DenDxBCFG_<4*t2%y^2g328%Wq5FBTh&+0BHx!l0o7iTazk#RhSHH9zoMR$@`s|^VA?>cU1arKbR$U`1 zlt9LBZWmPj#lH~$0@_pI$Svfq`=7Y~%l3cV{`A6s-tqGt!j)^n(QEl6B?I>H#!9y3 zHN^<5=cZ*-AR3ts&dJRPzu<0@FOR5xbXZ}QVHr4Zy*1#DLN`=^YkYCQ<$>0OfW|(8 zYS{O|%L7f>@zW5Y0UMivn3o6G1q84`C|Uqm-xZLS4&NZ!}e#2kt2jkn03I%2x~h-=Utnmv>Y{cc_OBL$(x$Tv?iAd>!qlQ1BZW;2att znt|c~X~9FQ2Y5apoDZOJp?BVSP$oPs56S*pDE&>&Ec+jzLgm)oH#SVRd>-BNf_whM z>CA;y@+m1?ctf~8EBNNb3vH(g7Lpb>c8W(!! z9nPdIQiN8XR>cNqQ_eymyA@c^z=+MFofO>P^GAFtx+gn*Kx@IS#pI8X$`x>1A5OQj z%IT^KpyBxIFVTe7qAHA-D=CLsD+GLRqR<7dUYCcFBaQBDCx+%p5cn6n4`c4kf zTzV;h?veW8d~E&>y*{wGBv+%t`MdQ%iQ%-Y+@iLW!9l19&)k7?VEzbPlVsUN@-tAH^8xEaPS?tzXIP;AbKTt_l=ZUx1btoxW}F zOm|D2!(!Wr)HbsEX0~mg(6%pmDS7Gd0~*%G1SMt(V#Fi2B=8U2k`CQUSU1f?L+g7B z?<}NSR}EtQZmE8E!j`ku3%0ggQ=8N@kT|pM_TYOtqWgg4KJfU>tov2L{pyoz;HBsk zEnL1Pn434MU2D}n+3KG3tXRE6s@{<>DYdLS>y!MuE9rfrvtM%d!<`&Q$9h{&ntrqZ z2V{lO%feN!*mg~7yOy*iZ5wS=(guo8dp&4WElolr6oB`-tMmQcYR>b7mL@JkgxGgQ z0>A5u)OAI4U6ou{6BFysn&h7QZ#{SmES}DG!P%Z`Y)_L98 zfp3%Ovz!3cE%Mi$Zsl~@XGE_9KHCN|ORKYD+nCfgmbB$+TQ`ZS>it5;F8p2x8|o{6 z#fuHQqy~sFx_0F{cVuRz&fTdBI3ElQr05>o*7lvs?mH#!J0tBo1JR-EQWd%GL0~$%Ro$cW5Zm0j);y4H9uS*{q~@Wu=DpeGy<+pI)I7S@JeF-9 z6Ppi9&4-iY>#YNsWwCWsY8_2Ze1Br2xgFxP-P5~BF(BlXSv(2+8v=HRAmF{@JGGhFP~sw&jgJsbW4RQJ!x3^4f7=%0@M@W8_Z zYt1{d%{#>AL8*B#IsPfYrT4FPi48+i!w_JiYsY$r8@NQ#tzP+0GyiesuYG^xgDBAr zN4V$3j(MqLK4nUoHahw;)tQ=1&7)@gi6FfN1cp(!hI5_0>7MjpdNAWgeyReESV_X^ zXt={@w5m9qQ0C=_S)S{ek$CV(5Jym)dL%QVdq#54h#fOh$4q`CFr0?wcbBDxS2AyA z8-|32px+PY-rD}J=giK4f?jLw?Ak~qD-^8#r=UmOH zBk#VRnAoVQOU|Zd1!s?7?%`&D!v5SnzW>}l{V%AN(`Mr@4%#97OS9p$-SSJj9)+DG z3RfLtK=A2@bUYn?*Wug_)WZJq3#tOtJxOxmk4VoJBe&whei|Kq-#=omx0C z9rF30%h((UOoMn8^BeXw%8ER+FX7HZ3NVtw-=IOoZ8u4ydy_!^hO_#?@}?0- zn22f@{C+}D+W5$x2?*3qscBen4NFxcn-w^NnP~6YwBVqXsBcTr>5+#)I2n_#&`wyL zi90FdLp$6+P2X8H{*7Jg9eaG|sqquLboi3syd+sJZ#r;!+_=((ul5yb3qjbk15ceY zw^hnjDg2``1L&81ShRg%k?DywArG07gsN>cWpZfcds~ z_nbv0U%5#gHGs{#@S8}uQPzrDK}Q$1%7yIJ=u$R|PASMWnrf8d$k3xma;=)wLL+dy zNXdvq9kvR2$O6XJAP>ogEHZS+PVOfB1P~h;DEurAGHu`=rI2lCWcW}d!?8%;FPq>B zIrNN0Dt(OkEK<$l^`Rg>UIG~V@k-f_QvVKlCy|FHLD^@#f#PKxQMSNr9O9?cAo>M7Dsrzu?snI_F*2z1Huv p-)rC0)spa=7vO!ShZALmh=B?&7~@(q~d(1XZA)w zHz$u;V0vbHy8CpWKBv!jPPbL56eavleDs^lGi6^sJznO4y)86}aIZI#`MQE{tA zRlMg~Bkrg%DxUMLUGA7M=8hZV;=RzCaCaNK-AQB8oie7}J;ok)ud&zNXY3Q-i>-&; zX=B>mZ|oQErPcxWVdG)(TyE*^BgP}{LF1tNsPU+K$T;LaW<2ITZanVJ7&GpyF^h5) zyLv}5YE0TZyd~Kq_UIjHLpGjxQL=a0V|OHbjFmQ2O^b_X zt9vi(9Q0YhCnhhN8Qjz_b(Efoa{eoR+oN 
zV(}ER1FOkg3)M`wZ8DGY?OHZ|c!2r^LE}a4ciO$K#QO#cq}_Yn370PTemkDnvBq?N zl-ujZ6^{iNscEy&sxQS8FNcg%XHu43S+cl|kxVUGEz7H;pjEGTxK&>@gSyXIoNqjp zJYo0+8f(2}1wkzzSFXNs>704yrE_P#|MJ3%v3lBDjq_)~9l-+atWytT?E-im=U-`u zj_+BmICHgwiHS3>dU&rZG^6QH#hnEI46a}wl2BTfBMARq2)?bptE?-L63Gn(qM)s9 z6m|Sw;ZD;*vwAvrx@)cKOV&-MTe{0!pRejIX06QWqV&}zC(!Y4=`@`dxbE3{=R5mZ&eSO7x+BT!zC-!EzWe$4|NaH!0b8E9*1^HpZ*rV`*N(>jteek z&6*;zYxM~c+5(DwC=ofmJKc#}dc z%0$wJvYrhy^kf(8;=Ag4F3LtZyL6AHzr6N`Vgg76LVtyM&Kl#o<6){Wy^W4nCuyUp z4wo46+nrGN8{%cD(}ct{8&;j^GbuA?FuTnT_@>YFTAR^l7ckvaUz9?f)1r_U2DKh> zM_!jC=N=?L$ih`<$at{gw9TN+>R5A3(OxjbNa8%r%N1heG?FHmc5Q;VZc8f)pRi?5 zv6Y`GH?%9#eHF{Qg_TT*8YMz#?klk}chpb=c3ZqB$t*s{#ZxV4T-|ru&!5Emzu}VB zzEAI1?z+NZ-ok>CbdFdzcdF&rtyb{-+ymwMANsBS?t|e^pf5>RpX@!64shx@pGWcx zJsd?6XI!h!eK0f!b>>(~pS@e`EuI1(dmCcGUQiFV>C&WN^S6(sr7CT2uJ9(KZKcNLhJ`p2XB2_~M2SqhPWk zDCIH!9lX_X4gCZxK~C^75+)gtNJFt@L$xQw`k>m{oq_?CwzDwX89Qg^anIVj?V??} zBO5t;QkZnv$cyr1(DQ&GQ>?gLVr5ogRW`Cb%62uBy%Jq)jE%DiR@p1zc{iT%6;BV` z8(2{6UH3Bfm_2?cV@$pc1u~|v5bTL7XF6E8-Bgp;J`$=4+w3$PX6uW<1h$T)U?~jS z!E^KTSc;cAizMdr11rIN$Sp6x;^{IT%rCb5#rdNrjn=qD2ZGw?WXnvUm$rWx-tgU||{ zy4mDbd&wj)H046{Tc!`=>pE-r=zC_k+GeP(1+3NJ=Rh&mu=>G1t=ytJiMeSC3vZfz zG3RGc@CJXOvJZY5*xtxEA8K)0a7*|uu0piV$mQIgA0^t`k{xMXC1J2(Twy6xTTSm;de6YabwuIUoTbuek)Z&AvJKcCB_~TYEDC7Le`u`cnBj~OX_Xw@%oYC zP~Hp=8UTpj@akG0kda)_Z;)5(eYW|X@JnRx0ht<>%%tspDkSx5%3Z%UB)vO>gTngd zqC42c`mjj+ZbCite-P-IuO!R^qTBB=C#)2g2U#8$dLxLd!@c5sffJX;unSgl)OEt@RoSi1TUYDk34t=T+gxKe+lo1{ce! zufXn*+?eckowY;ZW)1fGHS{N-PDuwx6%D39{3%G4QF#o16UwBFKV|Lsu!;CetuNFG zA#{IaamZ*0gYge|fJ-dFKq!P3U;# zSS3ifITC4`qgyZ@zmWbCM(i!QSMqNojSLQ2*T|uVHHP#-xNWe$8;ntJW&RN*uo@i`N ze-RlgY&CdJ|4cqo+r?j&}>N$^>=RGO8w^ zNSlGGc#n&OGm~e2iJB{zO@Iu~a#>sz9Ww*h_u)+P_>qgRoIQ8JJpamN^UUd2ubjSM zUbtvpJbmdBLMZ0K>5J#C#M#88@e9f@rsaLW zfP0Xx9U2ztof|nnB*TPud!T0#UTq|8LXHV^*pOjcfn)*$Zz@|9AIOG~K|mrE&hdXo zS&}2jj}qkoS2Rd!9&v#TzJ1Tm!g0>57tlr#k}Zdj>E--#0h0PrBrn4SSuSr@wj_Y1 zgFjN%OY7yR@Czv_-IRDcDn{kaDq@iGgmh9`FWzW-M^WeVYka(|#>dc@5&mmC7v<67 zucQ3tsGZvq&ir}>BisdemX9i1Dri>$(uz@aN?O;_(dAc^4@eWG%`uEEyG7wHw1>Q7 zXi>Q-(Gz&Gy-j7HO)aXT_t09D{xP&SLpl~AkWedFPXKrJti?0I607r9uqkE;ZDEtW> zky+h`GtV6xrJ>Yy;k)UJ2zcFed_GVoS#Q6=h~HpD{^fY+t3i;S=y&1MFh=a=_KG!W z$)n$bz$wDs#LCsD;%tK2#8Ueyxt+Cf{sY`#duU&wg-8CXc(42IRk4cqT~vONJ}=Pc zEYzukSa|s+d|7zrsV!mdBC~B6yC5DV3u40@aoiELN0=+#Hgg`Cb6(aMMN;s8$wO1M$zuL?y9I5Yr>>u zM28Ynvy&kyckKuf`)g1LjvkTWiSr8WEC9K+6!Y#{1ei>=WmVl-%1J96?oyCfO*@m6M z{jD39Z-^3`Tq4cu-m$`U3zCEd8}d40NArOVqXQ|=J$QvCyfqrT4ElKDi-s zcck|TEJ`)XBrZPeKxR;!wMEAY)-s3~+4B&g7r?)`A_6RsKLkmzzYrj=FVhKHs$jc; zqUfSTfKU_yPb@Aq5jca0iE6h+Qyft*hqHiUQg*B*sfRdQtT`cbf#Ch-b_^qAio`Q6 zQVo+T1RRleiT@rYFH_RPE&Qi={SRCL1wkdbs+J*LIwWlreof+A-h@Nooq+qXYpMF5*lp7h_c+UbF^6i zxs*0hZ%ckhzKZ}1YJ3_Su8P~VX}ttq_&xeR_D<@i5_#jqrL zQXK9H`jQ95Jo;RMKKxUZ7uY@TVgCkS1eW?hem|C_u)0C`HR|p zjY-RUXcP!;kGxA*t-M}+9X!<#sNGOy$s2nLV-0?F6oE?mdjl7+@;+3OnZ+sQ+E{!AY!vVA>e7d``@aeWwSAt7})G_`p#qRib?LZOOkZ|bo zzU7C=Kg6j2CaR|Ir9BA1!%LjrHc}<{`$1YjmJ0vpsEi%A(NzEPwdn6 z0erYKriA0@c(hsi1;`S+c)@UVQP9$7uY2CyMPZSz*M zbCmDFqAp$Z4qK*Ps9#S!@9T5=nI%?V(O3NrPgTKJPKXQn3@i*oye*JquE?u5yCgZ>27z)Tz`2$uX#izY`; zI%4a$mYBDV$xVv}Ml;uE33$~jcUjn-VQCcyEZ*!MxEK?*Gv9k%*?t_svL zu-)CQ8kbTWO~YQe8fWXi*KnHgLp1-`K{r!DvM74gV4Pt?>hlLi`@$I5=sK))67lE{ ziob`C{CyN1iPQ#GN2q_Rekec$8 zgs>f8aY`QNJ?<+aD6}AG-B$%$`Oi=}mT!W$e~hOGx^b=BjcvrK1A_R0g!p8i_z+o4 zWio0pX@5Z^e!iVfvrDHY^Pk}JHgaEkto+Y=R6dU=iKq~nZPfmM2e;OxQ`?_7jin@B zNRAfVJw#pj2k2}(Qnx&FKosK<6T%BejHR*ZIEO=Ph?w&v0-1wyh_^qZp^i}#8KKBY zpgiaK0+aI!eHjgKLJf;Ec7SaGZxg505w!@Y2c!mer%ne*EX3T?qhxBm-HCTm<@AUd zlmij?czV!NLqO0A*=@v1Ef2>yIRAZWR-=Z5H^Hytom)t@IkEhY=o5)vdL&OAq2hl^ 
z8RMjo;^+ic-!%FwTvBw^y47=gbE25Lh_fQe5h-^u34W4ZB#tl%A5>0A_zx+U{D^qu zfwciA9Hn3hU!;SFg$J8O@Utj9*)c^^^9YHSaoR(J_nL<{TNIwhGzCks2Y`IMPg9ET zMfquQpGH7c!`oc)4JSp8DOr@wRzVp~qcb%bG4S)0?xC{)BsMq;P&6r; zfZzh2%cgmwW3|#RIn#tq#@7BBR2E?i5kE*G2{bdpol17ovlm4hd>^tAd4OnA+{AZ0*T@G$kHQcFx4vxH&qT$Vojptq7dFF-}`0@$`eG z=4^`8Bqu<5j*8@|GEN5I+z?#B-eC})$46t|VAh;)6>+Ll5as0KaXL4%Qz`P1=~M+T z5sk5-W|Hb>v8tC4$|iq+pc_l2ge?DcMB{sWl%*5|0EcIOxc|x%Bj3|GZON z>K)l+0jD^AoW&Us^o&nYb(%oFmlB%hWd6iVib)g`UZW2mQ$mM^#3Ur9K+L2_gv#Vk jB|`T+R3+IDqd3nV(ga97`QoD)y@ diff --git a/BLIP_MRI/project/utils/__pycache__/__init__.cpython-311.pyc b/BLIP_MRI/project/utils/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index eb449ffc9c5fb03ec66471fe6f6d638f8ebe51d3..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 165 zcmZ3^%ge<81ZCIPr-SInAOZ#$p^VRLK*n^26oz01O-8?!3`I;p{%4TnFAM#G;^d;l zlH?5i;uQT1{fyMqjKp$%Cm+v%c;6sT{eq(WtkmQZ{nC=moMQd>_{_Y_lK6PNg34bU jHo5sJr8%i~MXW%BKvos=1BnmJjEsyQ7+^#ZGf)fw&q5}@ diff --git a/BLIP_MRI/project/utils/__pycache__/__init__.cpython-39.pyc b/BLIP_MRI/project/utils/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 8011179d38319cff544c3be7934621a03bcfbda9..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 147 zcmYe~<>g`kg0k!D(?RrO5P=LBfgA@QE@lA|DGb33nv8xc8Hzx{2;!H8enD|^QDRAQ zhJJC1eujQVYHCJexxSN+XF$Afkf(k@QGQlxa*2LvNoG#5etdjpUS>&ryk0@&Ee@O9 Q{FKt1R6CG~pMjVG0M^7J8vpzvSc~5WyzLbvSK@ypK<)y#E(3fIJ1;yBr)bgduJ5e zTB)lB1|kRUu8b+xnu5>C+10cW8Rc6 z=1ciw{!~@0Diw$YSY9v|Wc9gZb*u*PIg+)hx>#MRK3322oykzDA=bd+u4H3sLu^B_ zep9Ro>bsN8scyHV;d#!brS241sgd*B;R!+ z`R~{W`56Aja|d1@ml7n_3GJ#3{6$;_9Re^m91o1Em=Cxa&LH5dmcw`KfCv7@6Wb(l z=Sj5Y3rwDka=PW0tgu9K=v<7$RRZMy%k zOw|jS%W^s~DO0`Lk~|()RQPto%T&HVGFUKY5*64W^5Xz(1qITVR z6uPa*(3$h`6nvWM9?UtNiAyqtCN)Q)r!+A%p~_OR#xQV``5=!4;Q#oa0Nf^SqE?hn z!E9OQMe*Cn6nV$2GsPvyl-HLD8F7cJ^B-?(d=O*9u zm1%y;zP5fk&_UMAze29s-ym1GDIR8BnEaV10?(MLjw=Jby`u?rY+|TqIFsrfn;0Dh zvW~=udC%@XX^*sf=WyRp-=2{XN$%e(Z{N9ZXm8)n zeY^I>ckkS_Z)dMUhpo+#rFs#F!x`Wo;DMg;YaOTicJ5_8(d#66oXW#-WFGNKoTi}@ z%4F-Hav37q)XcW=xU?<4Z3GwtQbn#gej&>1&at=>S5>Opp$csTn(6jYS&a(1+pMU& zpiGR@QAPJ)>@s`-;{M`?L`4NA(j8DF$&!MMfaHmMy{mUz83qO$9_v-4-m%^>SssgD z={SpY33bySnZuq6G?>)4@RYL#7|rSnr0?c5qm6}9VrcsZiH`y;IDiF zV1^VtHo^6!hxlse_GzuVvfkY}@9qNOggyB{c=p=O>4h;ZuzP6>01U3(1qZ2b&DHg- z)a_ZW+oSD$O?%_*Y@L{^6KBrkPn?=Lnd3KV9Y?bK(E`s4Z*ckQ*0~MyEsHL#dhgOp z03ggd0Ly9a4Q|$*b9HH(Ph?#$YOWW*`nv;`jMp_+F+B8E&zwqBt(8nb{BGWd>Hf4L4(8ej-qM2K2 ziEqhv2bA!~Wkn3CmCZl@9qst_+fCU}MV^iFbQ55wT?jTKhyu_BSp4Lq?qDxPceA&c zkd&yC2{ohix-ISFnQ>UActuW*=-yODiYLV()Gaguh!vy^=p5RPU#S*BwG7LOBC>i3 z_1a|bw{l_6ng*{T(rExQ{PN>3o>8CJ<@00f(R(?`Vo`JdMu9GS5 zD7gbo%0`iO9<3?lA82D+-R1)a;<+?4ZC6d+BSgi%6rU-(RH2s^O>Qlf$haxK%slKW?&}BX=Y)h;>zV`Yir2+Q@2^C=f9`ICF^v4e@3Co z2$)~ri&>tj={gMaOzkU}=4snacN?~YAtMHSTF!aeZ!qbz%Fr`4{l6fKMG;&wQ{|YM zWkz7vD#PYQN`P@6!M%r}w<55mi4+YXv#yoIOcR-9rKzyiVMJV|!YXywb8l9rE=8@p zzh(+5D?dfkHB<16?M=@7wsvbM09wtGx4FbNmRua$?`c&@tLZY?$}x-c0d1U}@PB*( zz$wCkO+7^}bxdv;35*mp8d%T#FT?QX479m zrYnjXC+pwwpmPeQXvSX#O!Ab!8j^;us}R z-*KLNMCc1p6t&TP@Y3x_Tih`fP5Q4r`z7#~fsbYU;N-Dqrl8rjr;O$|fL+0z2~yJ5 zP*n!`aNC@H?p(oc$z_CrZOvZnpWD>`dX-g~&~7NN91DFOKU6UmItUq)1K)oTA0fiy z0I=tLHSmB59Hh3nAmB`!L&|r&#YXZ7bDpqlyICJwvVpN-S*8qac{XJ)LgYhRapGHX z(r(xnA#d$-!#2g0rOj$D!!);@FA|2MJD@Sp66c!HM;DB}TrXrO$Sje_%naMi{g0{j zmx%HJfUlUy1W{y)pd=P^(n_d%In=#$@m~L3F&jFN3mpJasjO|I@A zQ2|VLJspRPAsX}FE`2(`r%1Q2V+STn{#UWkRfvz~&_ zWC{mqix!-ae1z=r1i`FOs3MIKKO$H%76ZLoK= zUI)Ua<0)Abb&rUCfeCQch`LV{e>4$K7E7F>C}oC4kyZmy^?)eG)9DO&H^Hr{K<*~U 
zrG6~phccNYuz^Zx4Q6&B@M4J_?548AgE2aD-&6GYPzQq32vFM@l7}t}+F1fj0{#(F zGZnuAZx>uJQUIPt0J}(Vv&MHp$nstJ9lK{v%wC+^wNSf|(Q3El{5=}K8$y=fosV?R zyfo)psGav~d?$n~-QssGpMTswGJG4p=sedV`D?qB_^$4GxWG=_|$jdB##3ut^i-C*Py zhKTuKEIc3L)OLAQoVou2oK1`Q{o=i78 zDDp@?=98>GiNFgO7#V?+7`H?-<02gE zsKb?h10I^8_3`*Ukiho^usgUZB0`4o;9-IYCB{?fE2;41t@gkX5wWso3Jx_y^Ly+2 zZ=eRU6?;@Xf+;@W7c}#&y7&WORv6*Hv*W6)v)ac4c#&t>gNwdcnh}UEvctwwY3Y3j zB4xi42ldLHNgBN&s12o>1n6Y#(_5T6-q_hHp)97Gn{C@lbB_0d|68R1^%g0wjB z^Ng8cET4Y*I#i@F0N~1mlhI5{cbqfccA6fgZ(@nSItI#hBBlGmkvkM0N*F7sULAio zo`B64sO|)utAdMH_k&+lR!d2Wq6h9}=9-T>jD?te^o`P&*+-}ZNd?7Y)!J*!)@0K+ z&*a}=%g+IzJyAtAw9oQ4{rQGgc=;aGZkRu?SeLEs%GGv3hAY2eBfP5Mq$so>CiLEO z0J8PZ=jxw_Y|n%G#)Z0HG~E*C#aU0jL_8qc{m##}z$@EwEZ1^ume0AG07L%(##-|( z--@eg+0~SFg>$a3<_bd@_@@to8R*{huzAxbwM(ym)N;2a+q~zmJMNGD)wa*JWt&dS z9=&;L7Jd(#I!n>@8H;JNzlZU`MFH>^TSTSez}70g>*x*?683-=K%DHMdmv?&ESh?* zRE}?}7J>a3+ME@dWYNS{*=n^ufxTOcWIGS%8=&#+lh2()Hi?-01HM?sKg^v08!qPI zK)r7fG}M1g(9S}nXaR403;I9J!5-0~Mk@X6Q>!$(l~!O0ca$t)8$Aw~bpeE7RE`#< zluOWI$hGVs4KsM>TDG*Xf5oQJ-10m`RA^|kea=QOe{DtZyG`ATuPpJigMXxccx<8b z)~WeZ*+3*0h`@fKvL54;vr&icmeUh(oDB{kidHARjKGj)!w*eIAqUhY8Zo-FNMg(0 zCkZ8!ZQn7Be}+oRbpX(1zSNZm+xMYZWVav7Z9m3(2s@gMVCVs0f?_HhptZS$fjMVBSU2xl2}YKKk;R5=Fq#WSH5Lj2v}Fhiwo>)AQiAI_ zlae4r(9`e^yp0EVjRW5RO+gY~%$URK5yQ9#`w~=>C{AuzOMzn1zd~&VSJxSm57gfL zvF2(rdSFTol;N&{^+1ALIP-C$-EpI$AD5ZUKm(Or30E+nD<5Nb*z+D*00df@nAexA z)$(Xfqpw66zd@cxmN}xLja+IQ>4+1rFdv8qLoLY!42imki z8#>Uc@+=A|CXD>VWAF=kS!%!AY$5V;77Y*E31!6 z_X$t}{Ah;x#!;+ep#)Pcr~|-LVmL~^Cnf$9a478n(10s#wU>6x)#a)?76)_HU1sK5 zO=u44Sq)KcTz7+ghsW+i$~s2qvrp)M;V?>eX#Eb2g~j+>e8GL|vXKJFhVTWdK)6Kb zydoz?$6%|+U1sbS4KTf}dl;i1qclU=5~mBuM+(KYi>=r6M;Jp>(=ce!Y^7HaBoN>t zs5^(pGH9UE9!N!9#*%1HNLTe5(+Cku8wy=ugwQ={xG!c9MfVm9;h^@a?orsKJdD$6 zK&YUG{2{YlOO7M2D4HBU+48Ld@$ZP?Bv1ieMtN`5T<84OWp9h-ZFx`=)>?LCYx;6E zeVVTi{0=1-bEGpD*|QSazZ}`0jU3EH4zBVl!8%p&kl@Aw5!{D4xG=HeYhCuWW_^*I zFQWM(585~Vnf%F-d;DMcv)u=C-3PN>FXXyj$hIHKwI9-atq)8p=&CSPjX4@2DW6wL|`>KYb=TQAMO zH0#cn_yA^pX*Ymu{f=Dy4sdok2wdVATM2b7hv4egzTXaLubk6DUD?q2TZ~_kzadw@b?H>DzJH~D|8o8QZ2iGp{lQre_H}yU&0HY5v?a$blrg?K3j+EYkoLmy zsF$9B2t~7z{r~~4Ow8Xw@rZ(+Ly$(mn2Pbv2bf&t7Q$;N6v5~4SAGUy#{3mrq`o2F z7|u7f6}(NtrUC)*uqHHjWqw;hzzhco)D)Z;bCH@*!HqEwscS4(XKf77?-aoFF+mL{ zG{bb1fMbKOr$E*MHgML%Xcv{Cf4Zx5|4V>rwh-8`5?%You@$vkkh3Sznd_9$E#w57-O#>)GaJ>cJylS)heT2M6A^iE5i;cs2p5VJ+3@|4}W>GPy6POC(movkS9T{6!OG7!+r%v16NZZCg>wJZ^3rb#=*8|efWsw zKVk5Z9jag7P;BrzPq8uPD^U$J^bwar({9%p*ky;-vI3KZ+C*M%c#O}WV$>Voz diff --git a/BLIP_MRI/project/utils/__pycache__/data.cpython-39.pyc b/BLIP_MRI/project/utils/__pycache__/data.cpython-39.pyc deleted file mode 100644 index 952195c9ad6af77a4600596f9b9f6fabd938c9f9..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7990 zcmeHM%aa?$8K0L%qgm}sUK_9RLm048f+gbzVB>Ouy+E1Ri-XsY8dy(8>ej;^J=`&Eq*9ebiX3vu4GNs!*CS~q?*cZL+jXy_f0T8?P~aS|N6VywWl=gAJiECS!leBC)%fL8q=5_YF_fzJsoW$G@7Po z>I5@GtC{jrs&0i^)AnptPlf5G<2lWYmub45+nn;InprQ~oc5-hGu}*d)|*wZId4vl z*TZ=)hxP1mOLM_nXm0hkDtJ2F)_lx+Ox2xmdvnoS9E{)L?ZEg<_;_=tx3js++tu9d z?N)tmxTl%-@`G_tcu!!Q9`5z_vMH9mrFl>C)ax3XW;3@mHgnUsse4ah#B8$CU{-SC zIh=wp#|Z{-tb)1oEQh%{zVoK(JR&NiZ0rkDA(l^Y!{oi{siqumNZnOf;~Fe-Ji zg|=|8jdae$Tm*4j6bw0iii>!qeTBCgJuc+@MDw*Eim~@&=LKI8L8HalXymCN;^>+^ zA26U*yAWeRUb|5+WqJu{M;ypp2%6YhNS7dI+X3SOQ*ukdR5Z%nn6tr%3M6IpOQDv* z^Cq6Cj3UwUeU^^M=?zX09PHtv|rVCG$8&14p}nQ=nv?VcRDX1$~J z&0%jVyIw!2^)1jp)%&HJ$FCm8o#=R}RBOcbZn;=#H%s+yt%d_ug9_*|)%B37!G>YUeuYo$|X&z$$qy>+J45$(&o5|_GhBaGD9 za3r~x*1raVw!5)U==viphUX$2qd63s>*{mnjP4qSo-?-_Vh>;y>KOlxEX-{{xkgsS 
zGcH0#Jd;~IMH0YF&*o|F@C)mES=$A?A#ZFqCR3RyL(?0V)#J<9#$X6p9@r}0uj);z$Zs?YJ zEWEAHZP(FfbUa3{^o^oY+^{ArwNvE*$jl>YC-k16vyJWOA$mt2qKLIvXF4-(X?=Z3 zyQ$M3^o5lK#30DeRnO|uni(6h**95gQM+T^Rb$tz+pvsDdr4c}y_QPaeHxR3{luI4 z@9MgC_3YJWKGHtYs|K^Nx|6I9-lT(zX{aQHV*&@gX z6a56C-~LYHIuHH#gRsk^{m1hk>|c2fFIdq^y_m@1+OwO_RA`VOJlZr^RNhj?xEVxO z9&K(SgAYAt6U@J-!s;YDzq_D_sEW(#I|<7pP2EIb9%=rV&^+wCwP%z2B+-$id|1hO z*gcf?P3q7iP5obxWx^1wj#}%r5olC_tIgr(u?GC0Jh^F!-O9uJ%MUKc(b!26D@!!$ z8?s_@j3VL#CT!exJ@f8(Yqc;b`x{s=KKLP~ZeYP9&L2tUH_h8X0a&#O*#4o|OhD;u z-_vTGR%`2ITdz$-Omyd~d!O!pgyitBR(Gz!$71kU?`U^5+~DpSJgc`g+}63PhKe~2 znH`rv3O=nX{6+-V$0#0@>4Cc;ouOazpy!i0QR6Lsy(3P3N*j~+LEg^X{ob)}@oxMI zf>+-!o(S8hMJJ1+l_t(bek4VM7if9%y%70ns|(4=n2s1)`7ysXu2KwVZkS7Pf|{Ef z&UI;SO-p91pQm7mgBT)BoPQlElz-=*V2G~~33g8Y7Gokp%GP_YKJpPq%M{mFgzs9U zjWZ*IDrY9OQ|w2-!s%Xuq&PGLB$yJC;wXxHcp`GOZPWP5(T)4Pc0lZ?S?c)I!e^%pFK*gGb^y01^@%p_d)i_5G{eARP3h zeV?@}zAq@wm9xGdv|4RM*$D3=Kvf8nr7o-7ZihI1EW|t^uHqU>xk(e3(t3oG5SJJ( zh%bUPMZ7}AdW5o0=R7)yfGW&(4Qr31Z^KdN^sM3J1oeVlq$kjY(V7*sCw`>kFXM?$ zq8LWrY=$H5rkI-B91*}L7E3YvR?17WG;R=?{kWo z!SBzAsVlD`iQ&+l3rIw97-upa)w|Uy5?L80+8rNRvA9Af9{#94vV*tKpa|1fSs!1I zktOO=G^zPkAJL|dXmdR>Lfekhh(T?Ni9XV*I-(>4aSR|7C!yaV%j0MI+8y^UY$(O6 zqIX-1r}~Jhb`0JNG$z}!p$d32n4W^Qyg`RXMri>^)-ZH6yQU*3#zMbRhD@>Vo-R#9 z-9y8a9#CV9wMw4d6QnB2IQ<#R{RfZMD{dfb(f1pzMm%A0M-w;2P-*!%hJF3E_P92N zp+aSdpA{i{2%Oy*A|3kFN>EigegRnh1&>%kQJ5BIsZO~7M6$?gwwuy^D|wruRTGr> zlPOvwf*N$2avBkKIVd;6Mhq2~^TGQ;13n0ptbv>;VsAN(u$afgruea1JIPuUQi-Kd zo1Po)tTstdq;P#?9TGDA37}{@3N3Bxww~2(b53{lZKka|rfqCDdON-2KlG;pQL5e}DON@w!CF&yG_MWkZlt3TS7bDR^34(F!ZHSjJQl`M78ZVg14XHQ8 zML-{rLv{cB9q{4pcz7duCO9;9^N+5@Wmx_UZ*`Geh0iVCrEM-zkw{7s>I)ALNFCV; znI0VeL2P_<{tZt=_Ea;#wEJKiQZKzd--6c!)hH=VkK!GX4;&$-*5N9NpQ3$FPeg0o z-V$vkEImZe_y388B8aB<7B_A@7&Vc_Awxd2WYHhcAho-!T}g=?Wpi%oAHr@i^A@gK z%6yZkP*$gJM3-)>+>NfqsTkG_{mwAwgL+0>qLs*cVXaG8%c8ZEtxI9OLCyx5psBtI zE2txzqa&L`J*L_N`Jl_$%hP?MZ;@61*kqYSjj}?>;!#d$=B~B|%XisPnIU_T2<0yK za%WnMU(X};n{QY1N|YvOOp-aWhNcgdlnUHy-7iDowuMdu}sMAUzHjqq(_(?Cn~@mD54WmTnzr9NF8DBQ8`P&BS@g zJEZd{`$=hSb)K0u#YXaw#w;j~vWgJ{8=9P=20(HO5+!wjFQ`J6^XoTyctfps;O1vw zR4Kz)jvmn3k|IWzn(IQVYKrl1pyQ8t z)Gg##utU1`g#kPb`hSt#>0}&J(5S*b@e{nHGyFy*XX%3yzQd4bSq#b*K|+}jOMOiQ z#Yyc1$*HvC8}zv`@gB*@CC;S6oxZ{emx?JW$eN026fy-XjW0V!tK)4S|0Y(1@5{re zSjIO+F+;;=Y4}qUV+LP;k`G(zQxUT35l{T$+{94TtGr;DQn>~>m!uq2N0HdCd{#l~ zCP)k9&Ibn++4|>3v6Y5oD4K}jWW%>?25~+49FV6ah@hMCLpmoCxBet4MpQ)MKt!_Y ef6#CY_|baSr0^jNxMMnt2!wy8O;c+Q0`^}~lmU8xxE}N=aikHKh(Yh=k30cO$L4+TF~o z@*xzW7DHw5!GQ+5kd~a%gdF;B^i-rEh*pdcy{t7VY(*u#Rk5TDai z4Clgt1y)fmq9)X!7F9!9>{d|iSq^s36IxuNsdMjp}ic99YXum zIOyqD@e;};cbJ0340d~nW^1ye1bmh3CBVbRyGKB5q6(_uGVHA)FbS)1`AtlOcC}!Y z@&(Z;Xre{*yl%>ds935d%BCvXvaJzOR#n~6E%;*G7mFps(QW9MQr<4ntjM)ow(e?o zG%_ygCaBUWS26C-5n(JFM3dF~@Le%Vs;07u%iwZBq=LJXd zc6!353Sr(&P<5g(QP8x4e0Sp0EAwATUw%D5VH0abQ=EyCqZ`zdS^K^l^IZqygb^$P z4^4xpqIN$@9<2|bX(Y}(&o>g&_2Bd_d_Zo>EEOfka%4lIa?v(4>KjUX4`nXXqJVFFP~Q1HsuN8zVqBrBvY?K9i)?|D0Eb z9aRK1FblK)M<4f-3Jd5vToHh3!A0eO1Bkgk0MG6Sh{t#n<$;|y!Vho-uL;GpiNfG*Y}U}JqgX|Iu6!D(qcA#X5`@UM2;)U|*02Dl zbI|F`9As)vfcUG5>Zy6}em=D^T2G&9967ZI5%31>Fz0vvh6bq?prv9a!2KXez}?us zCGOZwQRI%fgYR>O9ECM(^Hc%>)>|Fx%;w}}kd4%QGc{jb*bc{P=f8cp`fw}tDm+pT zkFhRfqHvHk z?C>E{LU*Kkd}nv7&|#743OsTS5OQU&zp___o&ol94<|we>W;&L<~N6(uHPzX4R_D) zrs)!TffqA9ZcL}TNgdf#G&inTh5>m9L8fj%URK=5yrU6$+0fjOQ-Wk7%$JPd;(IPO zDRV{iXyb-#nNY1fB$<5TK1n!!9A?sK5D=m5czQ$r>DZGK&Ed&Le5x6rY9ky?w+04l zm$&-{A79(LxOMx{^&hS``bL|5qxHxrKX;G00{H;6kWr?OnZN0iBkbEHCjmUhzOoKj zYYTzU$u@eAh@eEOdd1Vt7_`Im?XV4Yg_vax19xwN=#Cha2gtOi2XhhTc&)o+94;Z< zu^@s9X{`p{O>O4VS$KBGKzxBP|MA?2Gj^?bc0gW3n;4_@gbH*ZtOh_@C5d*sq!W!w 
zc_OPBf%F6JV%v+EcW7^;vw6+b?%8CneB6+W%c?BhI+)&_ucu)xWgh$q7bJLW2JWrf zwaN7hUjLxv2YU4qkR$;AH%Moh$VClkMA?tEXsIQr$?VAd$VK-olPsVgGF#_B08zp* zPw~@Bug9iejZHVk&Nau*0oI7kG(m=En&BBBPiU~!H&An0@&4L0HxGHYm`S;UMcrWD zu$JAZ4h}e$Wl+vwx90iztPxX+mi<`5(s(=%(WaS)JpRhSST9I+ira zY0kS-``xo%mavU@<(6fnS63sm-rncPdBE5K(BDC{0~q5LidOm6Lh&lUTIhJa`);A3 zdgpGTbiMm-_eAg~ZL}wbLU^c+Ks-OoU%y@Dufw80gi{RjMK4MY);{03xVEsqP>r?{ R1OGB;_XOCIcLhKFe*mdUb(8=A diff --git a/BLIP_MRI/project/utils/__pycache__/utils.cpython-39.pyc b/BLIP_MRI/project/utils/__pycache__/utils.cpython-39.pyc deleted file mode 100644 index 1f2a54661bc86f3437cf6bed169fda06fd6f099a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1778 zcmZux&2Jnv6t`#gV|KHfq#>k;7ML4d6zvuew^f8%gb0ENk&44g)yO+`cPE{lnb@AR ziMAIu61{Tg0LPs8bC7&3mz=qAfuB8_Y^v5HKilv5ef{3^ux__aFi!sZbMkkMkXQC_ zbAU&_#8f9BoN$_wQMJ+$#o0@}nLqMr#jB0#?oMOY95uo7dF`Bx0=^+yqRs2SdAz}! z=k?JV4|ofbj_C0Af{waa*RZZ*?O@&DUA_+aCZ~s_w|NQQ{kP zPs3T6>ZHIXizY>>`k}M(P5w-r2)P?38B8V0nb=+432|nr6pWuj9;YQ2++G|b;AyCm zGtpm$QmWlh#Vk#-iB;lxCc^Q`r)4{($SQ@Oge;Cr1jKA~kyb_s)Yk{faE;3s0l03(EbcME1E~o!$eF*M{U;=Bd4J$VnD^u2;}Cu`}&qnkXIT zvtrL0`UjJo-9u4hTPP6$x+0ghV*O#7gRl0mxl*~bQCrOt)xk_>o zZApLueSQIw-2k{=FsZ~{h*blG&~;j)XFE~m^j4s{Ys+%U4{%X_Xo+ibt6kyNTq1EL`T#P9SEVw-Y7<@?+}BpjXdwsy0I zm7TUASIu@rej=0{dmI?bcOd`*T`7*sL<$v6B;!3#u7Pi8QDx%NaUIYaPDCc27jo|& zF31g#v%4vqjX8^++-$b1YmdH0AN%;Lp|6-^Gl2>yIBd*vUZMid^bPyhwmydGCH6bO zURr-W@8F|Y&)n8n?&npzxSUrJufh(7RXTgNW@-l{paC8OrXQ?r1wqrE G1OI(*O diff --git a/BLIP_MRI/sample_scripts/BLIP_MRI_Blip_DDP_interactive.sh b/BLIP_MRI/sample_scripts/BLIP_MRI_Blip_DDP_interactive.sh old mode 100644 new mode 100755 diff --git a/BLIP_MRI/sample_scripts/BLIP_MRI_Blip_T1_DDP_interactive.sh b/BLIP_MRI/sample_scripts/BLIP_MRI_Blip_T1_DDP_interactive.sh old mode 100644 new mode 100755 diff --git a/BLIP_MRI/sample_scripts/BLIP_MRI_LLaVa_T1_DDP_interactive.sh b/BLIP_MRI/sample_scripts/BLIP_MRI_LLaVa_T1_DDP_interactive.sh old mode 100644 new mode 100755 index f6c6ffd..f4f1765 --- a/BLIP_MRI/sample_scripts/BLIP_MRI_LLaVa_T1_DDP_interactive.sh +++ b/BLIP_MRI/sample_scripts/BLIP_MRI_LLaVa_T1_DDP_interactive.sh @@ -1,14 +1,14 @@ set +x -cd /pscratch/sd/h/heehaw/BLIP_MRI/project #TODO: Change to your own scratch space +cd /YOUR_PROJECT_DIRECTORY #TODO: Change to your own scratch space module load python #module load pytorch/1.13.1 module load cpe/23.03 -conda activate /pscratch/sd/h/heehaw/anaconda/BLIP_MRI #TODO: Change to your own conda env +conda activate BLIP_MRI_llava #TODO: Change to your own conda env # conda activate py39 # pip install timm #export MASTER_ADDR=`/bin/hostname -s` @@ -16,9 +16,9 @@ conda activate /pscratch/sd/h/heehaw/anaconda/BLIP_MRI #TODO: Change to your o #export MASTER_PORT=$(shuf -i 29500-65535 -n 1) export LIBRARY_PATH=$LD_LIBRARY_PATH -export TORCH_EXTENSIONS_DIR=/pscratch/sd/h/heehaw #TODO: Change to your own scratch space -export HF_HOME=/pscratch/sd/h/heehaw/huggingface #TODO: Change to your own scratch space -export TORCH_HOME=/pscratch/sd/h/heehaw/ #TODO: Change to your own scratch space +export TORCH_EXTENSIONS_DIR=/pscratch/sd/ #TODO: Change to your own scratch space +export HF_HOME=/pscratch/sd/ #TODO: Change to your own scratch space +export TORCH_HOME=/pscratch/sd/ #TODO: Change to your own scratch space # #recent version (24.3.30) From f3d4ec272c8451329e0162f93900ea9d417e7ba9 Mon Sep 17 00:00:00 2001 From: Sue0515 Date: Wed, 10 Dec 2025 04:24:08 -0800 Subject: [PATCH 2/3] feat: Add LLaVA-NeXT-Interleave comparison task pipeline - Add general comparison JSON 
generator for categorical/numerical tasks - Support complete data separation (inter-split and intra-split) - Add flexible attribute encoding and multi-dataset compatibility --- .gitignore | 5 + BLIP_MRI/README.md | 127 +++ ..._Deepspeed_joint_multiturn_comparison.yaml | 40 + .../generate_json_general_comparison_split.py | 892 +++++++++++++++++ .../generate_json_sex_comparison_split.py | 500 ++++++++++ ...taset_T1_LLaVaNextInterleave_comparison.py | 301 ++++++ ...VaNextInterleave_comparison_hf_joint_T1.py | 257 +++++ BLIP_MRI/project/model/Bblip_t5_interleave.py | 153 +++ .../Trainer_LLaVaNextInterleave_comparison.py | 893 ++++++++++++++++++ ..._LLaVaNextInterleave_T1_DDP_interactive.sh | 21 + 10 files changed, 3189 insertions(+) create mode 100644 BLIP_MRI/project/config/Brain_LLaVa_train_Deepspeed_joint_multiturn_comparison.yaml create mode 100644 BLIP_MRI/project/data/generate_json_general_comparison_split.py create mode 100644 BLIP_MRI/project/data/generate_json_sex_comparison_split.py create mode 100644 BLIP_MRI/project/dataset/dataset_T1_LLaVaNextInterleave_comparison.py create mode 100644 BLIP_MRI/project/main_BLLaVaNextInterleave_comparison_hf_joint_T1.py create mode 100644 BLIP_MRI/project/model/Bblip_t5_interleave.py create mode 100644 BLIP_MRI/project/utils/Trainer_LLaVaNextInterleave_comparison.py create mode 100755 BLIP_MRI/sample_scripts/BLIP_MRI_LLaVaNextInterleave_T1_DDP_interactive.sh diff --git a/.gitignore b/.gitignore index fd400d2..fce2b0e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,5 +6,10 @@ BLIP_MRI/project/dataset/__pychache__ BLIP_MRI/project/model/__pychache__ BLIP_MRI/project/utils/__pychache__ BLIP_MRI/project/*.json +BLIP_MRI/project/data/*.json +BLIP_MRI/project/check_metadata.py +BLIP_MRI/project/data/abcd +BLIP_MRI/project/data/gard +BLIP_MRI/project/data/ukb __pycache__/ *.pyc diff --git a/BLIP_MRI/README.md b/BLIP_MRI/README.md index 8b13789..271d5b6 100644 --- a/BLIP_MRI/README.md +++ b/BLIP_MRI/README.md @@ -1 +1,128 @@ +# LLaVA-NeXT-Interleave for Brain MRI Comparison Tasks + +Multi-image comparison framework using LLaVA-NeXT-Interleave for brain MRI analysis. + +--- + +## Overview + +**Architecture:** LLaVA-NeXT-Interleave (Qwen-0.5b) +**Task:** Reference-augmented comparison +**Format:** Multi-turn conversation with interleaved images + +**Example:** +``` +Turn 1 (Reference): +User: "Here is a brain scan from a male participant. " +Assistant: "Understood. I've analyzed the reference scan." + +Turn 2 (Query): +User: "Compare this scan with the reference. What is the sex?" +Assistant: "Based on the comparison, the subject is male." +``` + +--- + +## 1. 
Data Preparation
+
+### Generate Comparison JSON
+
+#### Sex (or any categorical attribute) Comparison (Categorical)
+
+```bash
+python generate_json_general_comparison_split.py \
+    --study_sample ABCD \
+    --meta_path /path/to/phenotype.csv \
+    --img_dir /path/to/images \
+    --target_col sex \
+    --task_type categorical \
+    --output_dir ./data \
+    --num_pairs 3 \
+    --seed 1234
+```
+
+**Output:**
+- `data/ABCD_sex_comparison_tasks_train.json`
+- `data/ABCD_sex_comparison_tasks_val.json`
+- `data/ABCD_sex_comparison_tasks_test.json`
+
+#### BMI_sds (or any numerical attribute) Regression (Numerical)
+
+```bash
+python generate_json_general_comparison_split.py \
+    --study_sample ABCD \
+    --meta_path /path/to/phenotype.csv \
+    --img_dir /path/to/images \
+    --target_col BMI_sds \
+    --task_type numerical \
+    --output_dir ./data \
+    --num_pairs 3 \
+    --seed 1234
+```
+
+**Output:**
+- `data/ABCD_BMI_sds_comparison_tasks_train.json`
+- `data/ABCD_BMI_sds_comparison_tasks_val.json`
+- `data/ABCD_BMI_sds_comparison_tasks_test.json`
+
+**Key Parameters:**
+- `--target_col`: Phenotype column to predict; any categorical or numerical column in the CSV works
+- `--num_pairs`: Number of reference pairs per query subject
+
+---
+
+## 2. Data Split Logic
+
+**Complete Separation:**
+- **Inter-split:** Train/Val/Test subjects do NOT overlap
+- **Intra-split:** Query and Reference pools do NOT overlap within each split
+
+**Example (1000 subjects, 70/15/15 split):**
+```
+Train: 700 subjects
+  ├─ Query: 350 subjects
+  └─ Reference: 350 subjects (different from query!)
+
+Val: 150 subjects
+  ├─ Query: 75 subjects
+  └─ Reference: 75 subjects
+
+Test: 150 subjects
+  ├─ Query: 75 subjects
+  └─ Reference: 75 subjects
+```
+
+**Why?** Test subjects NEVER appear during training, not even as references, so the test split measures true generalization.
+
+---
+
+## 3. Training
+
+### Configure
+
+Edit `config/Brain_LLaVa_train_Deepspeed_joint_multiturn_comparison.yaml`:
+
+```yaml
+dataset:
+  target_col: "sex"  # or "age", "diagnosis", etc.
+
+  train_json:
+    - "./data/ABCD_sex_comparison_tasks_train.json"
+  val_json:
+    - "./data/ABCD_sex_comparison_tasks_val.json"
+  test_json:
+    - "./data/ABCD_sex_comparison_tasks_test.json"
+
+```
+
+## 4. Troubleshooting
+
+**Error: File not found**
+- Check that the image paths in the JSON files match the actual files on disk
+
+**Error: Image token mismatch**
+- Ensure you are using the updated `Trainer_LLaVaNextInterleave_comparison.py`
+
+**Low metrics**
+- Check the data split: are train/val/test balanced?
+- Review the generation logs for prediction quality
diff --git a/BLIP_MRI/project/config/Brain_LLaVa_train_Deepspeed_joint_multiturn_comparison.yaml b/BLIP_MRI/project/config/Brain_LLaVa_train_Deepspeed_joint_multiturn_comparison.yaml
new file mode 100644
index 0000000..499d177
--- /dev/null
+++ b/BLIP_MRI/project/config/Brain_LLaVa_train_Deepspeed_joint_multiturn_comparison.yaml
@@ -0,0 +1,40 @@
+wandb:
+  API_KEY: "YOUR_API_KEY"
+
+seed: 1234
+
+dataset:
+  # Target column
+  # Single-task:
+  target_col: "sex"  # Options: 'sex', 'age', 'diagnosis', 'bmi', etc.
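+  # Hypothetical example (not part of the shipped config): for the numerical
+  # BMI_sds task generated in the README above, you would instead set
+  # target_col: "BMI_sds" and point train_json/val_json/test_json at the
+  # ABCD_BMI_sds_comparison_tasks_*.json files.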
+
+  # Multi-task (future support):
+  # target_col: ["sex", "age"]  # List for multiple targets
+
+  train_json:
+    - "./data/ABCD_sex_comparison_tasks_train.json"
+  val_json:
+    - "./data/ABCD_sex_comparison_tasks_val.json"
+  test_json:
+    - "./data/ABCD_sex_comparison_tasks_test.json"
+
+  # Image size
+  img_size: [120, 120, 120]
+
+
+model:
+  hf_name: "llava-hf/llava-interleave-qwen-0.5b-hf"
+  patch_size: [10, 10, 10]
+
+
+trainer:
+  max_epochs: 50
+  learning_rate: 0.00005
+  warmup_steps: 500
+  weight_decay: 0.01
+  per_device_batch_size: 1  # Multi-image (reference + query) requires more memory
+  gradient_accumulation_steps: 4
+  gradient_checkpointing: True
+  logging_steps: 1
+  ckpt_dir: "./hf_results/{}/last.ckpt"
+  resume_training: False
diff --git a/BLIP_MRI/project/data/generate_json_general_comparison_split.py b/BLIP_MRI/project/data/generate_json_general_comparison_split.py
new file mode 100644
index 0000000..36da9e1
--- /dev/null
+++ b/BLIP_MRI/project/data/generate_json_general_comparison_split.py
@@ -0,0 +1,892 @@
+"""
+Generate JSON files for Comparison Tasks
+
+Supports:
+- Categorical tasks: sex, diagnosis, etc.
+- Numerical tasks: age, BMI, glucose, etc.
+
+This script ensures complete separation:
+- Inter-split: Train/Val/Test subjects do NOT overlap
+- Intra-split: Within each split, Query and Reference pools do NOT overlap
+"""
+
+import os
+import json
+import pandas as pd
+import glob
+import numpy as np
+from pathlib import Path
+import random
+
+def load_subjects_and_images(meta_path, img_dir, subject_id_col, target_col, study_sample='ABCD'):
+    """Load metadata and available images"""
+
+    # Load metadata
+    meta = pd.read_csv(meta_path)
+    meta = meta[[subject_id_col, target_col]].dropna()
+
+    # Load available images
+    image_files = glob.glob(os.path.join(img_dir, '*.nii.gz'))
+    image_dict = {}
+
+    suffix_len = -7  # Remove '.nii.gz'
+
+    for img_path in image_files:
+        filename = os.path.basename(img_path)
+        subject_id = filename[:suffix_len]
+        image_dict[subject_id] = img_path
+
+    if subject_id_col == 'subject_id':  # GARD
+        # Check the dtype of subject_id in the CSV
+        if pd.api.types.is_integer_dtype(meta[subject_id_col]):
+            # If subject_id is an integer, convert the image_dict keys to int as well
+            # and strip suffixes such as '_brain'
+            image_dict_converted = {}
+            for k, v in image_dict.items():
+                # Keep only the numeric part (drop '_brain')
+                k_clean = k.replace('_brain', '')
+                try:
+                    image_dict_converted[int(k_clean)] = v
+                except ValueError:
+                    continue  # Skip if the conversion fails
+            image_dict = image_dict_converted
+
+    # Filter subjects with both metadata and images
+    meta = meta[meta[subject_id_col].isin(image_dict.keys())].reset_index(drop=True)
+
+    # Remap sex values to 0/1 if target_col is 'sex'
+    if 'sex' in target_col.lower():
+        unique_values = meta[target_col].unique()
+        if set(unique_values).issubset({1, 2}):
+            print(f"Sex values are 1/2 format. Remapping: 1->0 (male), 2->1 (female)")
+            meta[target_col] = meta[target_col] - 1
+        elif set(unique_values).issubset({'M', 'F', 'Male', 'Female', 'male', 'female'}):
+            print(f"Sex values are string format.
Remapping: M/Male/male->0, F/Female/female->1") + meta[target_col] = meta[target_col].map({ + 'M': 0, 'Male': 0, 'male': 0, + 'F': 1, 'Female': 1, 'female': 1 + }) + elif not set(unique_values).issubset({0, 1}): + print(f"[WARN] Sex values are unexpected format: {unique_values}") + print(f" Expected: 0/1, 1/2, or M/F variants") + + return meta, image_dict + + +def detect_task_type(meta, target_col): + """ + Automatically detect if task is categorical or numerical + + Returns: + 'categorical' or 'numerical' + """ + unique_values = meta[target_col].unique() + + # Check if all values are numeric + if pd.api.types.is_numeric_dtype(meta[target_col]): + # If small number of unique values (< 10), likely categorical + if len(unique_values) <= 10: + return 'categorical' + else: + return 'numerical' + else: + return 'categorical' + + +def parse_categorical_mapping(meta, target_col, mapping_str=None): + """ + Parse categorical mapping from string or auto-detect + + Args: + meta: DataFrame + target_col: Target column name + mapping_str: Optional mapping string like "male=0,female=1" or "1=male,2=female" + + Returns: + value_to_label: dict mapping numeric values to string labels + label_to_value: dict mapping string labels to numeric values + """ + + if mapping_str: + # Parse user-provided mapping + # Supports: "male=0,female=1" or "0=male,1=female" + pairs = mapping_str.split(',') + value_to_label = {} + label_to_value = {} + + for pair in pairs: + parts = pair.strip().split('=') + if len(parts) == 2: + key, val = parts[0].strip(), parts[1].strip() + # Try to determine which is numeric + try: + num_val = int(key) + str_label = val + except ValueError: + num_val = int(val) + str_label = key + + value_to_label[num_val] = str_label + label_to_value[str_label] = num_val + else: + # Auto-detect from data + unique_values = sorted(meta[target_col].unique()) + + # Special handling for 'sex' column + if 'sex' in target_col.lower(): + # Common sex encoding: 0/1 or 1/2 + if set(unique_values) == {0, 1}: + value_to_label = {0: "male", 1: "female"} + label_to_value = {"male": 0, "female": 1} + print(" Detected sex column with 0/1 encoding (0=male, 1=female)") + elif set(unique_values) == {1, 2}: + # Will be remapped to 0/1 later + value_to_label = {1: "male", 2: "female"} + label_to_value = {"male": 1, "female": 2} + print(" Detected sex column with 1/2 encoding (1=male, 2=female)") + else: + # Fallback for unexpected values + value_to_label = {val: str(val) for val in unique_values} + label_to_value = {str(val): val for val in unique_values} + # Check if string values (use as-is) + elif not pd.api.types.is_numeric_dtype(meta[target_col]): + # String categorical values - use original values as labels + value_to_label = {val: str(val) for val in unique_values} + label_to_value = {str(val): val for val in unique_values} + print(f" Using original string values as labels: {list(unique_values)}") + # Check if already 0-indexed integers + elif set(unique_values) == set(range(len(unique_values))): + # Use generic labels + value_to_label = {i: f"class_{i}" for i in unique_values} + label_to_value = {f"class_{i}": i for i in unique_values} + # Check if 1-indexed + elif set(unique_values) == set(range(1, len(unique_values) + 1)): + # Remap to 0-indexed with generic labels + value_to_label = {i: f"class_{i-1}" for i in unique_values} + label_to_value = {f"class_{i-1}": i for i in unique_values} + else: + # Mixed values, use as-is + value_to_label = {val: str(val) for val in unique_values} + label_to_value = {str(val): val 
for val in unique_values} + + return value_to_label, label_to_value + + +def remap_categorical_values(meta, target_col, value_to_label): + """ + Remap categorical values to 0-indexed if needed + + Returns: + meta: DataFrame with remapped values + value_to_label: Updated mapping + """ + + # Check if values need remapping + unique_values = sorted(meta[target_col].unique()) + + if set(unique_values) == set(value_to_label.keys()): + # Already correct, but might need 0-indexing + if min(unique_values) == 1: + # 1-indexed → 0-indexed + print(f"Remapping {target_col} from 1-indexed to 0-indexed:") + for old_val in unique_values: + new_val = old_val - 1 + label = value_to_label[old_val] + print(f" {old_val} ({label}) → {new_val}") + meta[target_col] = meta[target_col].replace(old_val, new_val) + + # Update mapping + new_value_to_label = {old_val - 1: label for old_val, label in value_to_label.items()} + return meta, new_value_to_label + + return meta, value_to_label + + +def split_subjects_categorical(meta, subject_id_col, target_col, value_to_label, + train_ratio=0.7, val_ratio=0.15, seed=1234): + """ + Split subjects for categorical tasks (stratified by class) with COMPLETE SEPARATION + + Returns a dictionary with 6 subject pools: + - train_query, train_ref + - val_query, val_ref + - test_query, test_ref + + This ensures: + 1. Inter-split: Train/Val/Test subjects don't overlap + 2. Intra-split: Query and Reference pools don't overlap within each split + """ + + random.seed(seed) + np.random.seed(seed) + + train_subjects = [] + val_subjects = [] + test_subjects = [] + + # First split by class (stratified) + for value, label in value_to_label.items(): + class_subjects = meta[meta[target_col] == value][subject_id_col].values.tolist() + random.shuffle(class_subjects) + + n = len(class_subjects) + n_train = int(n * train_ratio) + n_val = int(n * val_ratio) + + train_subjects.extend(class_subjects[:n_train]) + val_subjects.extend(class_subjects[n_train:n_train+n_val]) + test_subjects.extend(class_subjects[n_train+n_val:]) + + print(f" {label} (value={value}): {n} subjects") + print(f" Train: {n_train}, Val: {n_val}, Test: {n - n_train - n_val}") + + # Further split each set into query and reference (50/50) + def split_query_ref(subjects_list): + """Split subjects into query and reference pools""" + random.shuffle(subjects_list) + n = len(subjects_list) + query = subjects_list[:n//2] + ref = subjects_list[n//2:] + return query, ref + + train_query, train_ref = split_query_ref(train_subjects) + val_query, val_ref = split_query_ref(val_subjects) + test_query, test_ref = split_query_ref(test_subjects) + + print(f"\n Query/Reference split:") + print(f" Train: Query={len(train_query)}, Ref={len(train_ref)}") + print(f" Val: Query={len(val_query)}, Ref={len(val_ref)}") + print(f" Test: Query={len(test_query)}, Ref={len(test_ref)}") + + return { + 'train_query': train_query, + 'train_ref': train_ref, + 'val_query': val_query, + 'val_ref': val_ref, + 'test_query': test_query, + 'test_ref': test_ref + } + + +def split_subjects_numerical(meta, subject_id_col, target_col, + train_ratio=0.7, val_ratio=0.15, seed=1234): + """ + Split subjects for numerical tasks (stratified by value bins) with COMPLETE SEPARATION + + Returns a dictionary with 6 subject pools: + - train_query, train_ref + - val_query, val_ref + - test_query, test_ref + + This ensures: + 1. Inter-split: Train/Val/Test subjects don't overlap + 2. 
Intra-split: Query and Reference pools don't overlap within each split + """ + + random.seed(seed) + np.random.seed(seed) + + # Bin values into quartiles for stratification + meta['_bin'] = pd.qcut(meta[target_col], q=4, labels=False, duplicates='drop') + + train_subjects = [] + val_subjects = [] + test_subjects = [] + + # First split each bin into train/val/test + for bin_idx in sorted(meta['_bin'].unique()): + bin_subjects = meta[meta['_bin'] == bin_idx][subject_id_col].values.tolist() + random.shuffle(bin_subjects) + + n = len(bin_subjects) + n_train = int(n * train_ratio) + n_val = int(n * val_ratio) + + train_subjects.extend(bin_subjects[:n_train]) + val_subjects.extend(bin_subjects[n_train:n_train+n_val]) + test_subjects.extend(bin_subjects[n_train+n_val:]) + + bin_meta = meta[meta['_bin'] == bin_idx] + print(f" Bin {bin_idx} ({bin_meta[target_col].min():.1f}-{bin_meta[target_col].max():.1f}): {n} subjects") + print(f" Train: {n_train}, Val: {n_val}, Test: {n - n_train - n_val}") + + meta = meta.drop('_bin', axis=1) + + # Further split each set into query and reference (50/50) + def split_query_ref(subjects_list): + """Split subjects into query and reference pools""" + random.shuffle(subjects_list) + n = len(subjects_list) + query = subjects_list[:n//2] + ref = subjects_list[n//2:] + return query, ref + + train_query, train_ref = split_query_ref(train_subjects) + val_query, val_ref = split_query_ref(val_subjects) + test_query, test_ref = split_query_ref(test_subjects) + + print(f"\n Query/Reference split:") + print(f" Train: Query={len(train_query)}, Ref={len(train_ref)}") + print(f" Val: Query={len(val_query)}, Ref={len(val_ref)}") + print(f" Test: Query={len(test_query)}, Ref={len(test_ref)}") + + return { + 'train_query': train_query, + 'train_ref': train_ref, + 'val_query': val_query, + 'val_ref': val_ref, + 'test_query': test_query, + 'test_ref': test_ref + } + + +def generate_comparison_tasks_categorical( + query_subjects, + reference_subjects, + meta, + image_dict, + subject_id_col, + target_col, + value_to_label, + num_pairs_per_subject=5, + same_class_ratio=0.5, + seed=1234 +): + """Generate comparison tasks for categorical target""" + + random.seed(seed) + + query_meta = meta[meta[subject_id_col].isin(query_subjects)].reset_index(drop=True) + + # Group reference subjects by class + ref_meta = meta[meta[subject_id_col].isin(reference_subjects)] + ref_by_class = {} + for value in value_to_label.keys(): + ref_by_class[value] = ref_meta[ref_meta[target_col] == value][subject_id_col].values.tolist() + + print(f"\nReference pool: {len(reference_subjects)} subjects") + for value, label in value_to_label.items(): + print(f" {label}: {len(ref_by_class[value])}") + + all_tasks = [] + + for _, row in query_meta.iterrows(): + query_id = row[subject_id_col] + query_value = int(row[target_col]) + query_label = value_to_label[query_value] + query_img_path = image_dict[query_id] + + # Determine same-class vs different-class pairs + num_same = int(num_pairs_per_subject * same_class_ratio) + num_diff = num_pairs_per_subject - num_same + + # Sample same-class references + same_pool = [s for s in ref_by_class[query_value] if s != query_id] + if len(same_pool) >= num_same: + same_refs = random.sample(same_pool, num_same) + else: + same_refs = same_pool + + # Sample different-class references + diff_pool = [] + for value in value_to_label.keys(): + if value != query_value: + diff_pool.extend(ref_by_class[value]) + + if len(diff_pool) >= num_diff: + diff_refs = random.sample(diff_pool, num_diff) 
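+        # Fallback (descriptive note): when the different-class reference pool is
+        # smaller than num_diff, all available references are used instead of sampling,
+        # mirroring the same-class branch above.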
+ else: + diff_refs = diff_pool + + # Create tasks for same-class + for ref_id in same_refs: + ref_value = query_value + ref_label = query_label + ref_img_path = image_dict[ref_id] + + task = create_task_categorical( + query_id, query_value, query_label, query_img_path, + ref_id, ref_value, ref_label, ref_img_path, + comparison_type='same', + target_name=target_col + ) + all_tasks.append(task) + + # Create tasks for different-class + for ref_id in diff_refs: + ref_value = int(meta[meta[subject_id_col] == ref_id][target_col].values[0]) + ref_label = value_to_label[ref_value] + ref_img_path = image_dict[ref_id] + + task = create_task_categorical( + query_id, query_value, query_label, query_img_path, + ref_id, ref_value, ref_label, ref_img_path, + comparison_type='different', + target_name=target_col + ) + all_tasks.append(task) + + print(f"Generated {len(all_tasks)} comparison tasks") + return all_tasks + +def generate_comparison_tasks_categorical( + query_subjects, + reference_subjects, + meta, + image_dict, + subject_id_col, + target_col, + value_to_label, + num_pairs_per_subject=5, + same_class_ratio=0.5, + seed=1234 + ): + """Generate comparison tasks for categorical target""" + + random.seed(seed) + + query_meta = meta[meta[subject_id_col].isin(query_subjects)].reset_index(drop=True) + + # Group reference subjects by class + ref_meta = meta[meta[subject_id_col].isin(reference_subjects)] + ref_by_class = {} + for value in value_to_label.keys(): + ref_by_class[value] = ref_meta[ref_meta[target_col] == value][subject_id_col].values.tolist() + + print(f"\nReference pool: {len(reference_subjects)} subjects") + for value, label in value_to_label.items(): + print(f" {label}: {len(ref_by_class[value])}") + + all_tasks = [] + + for _, row in query_meta.iterrows(): + query_id = row[subject_id_col] + # Convert to Python native types + if isinstance(query_id, (np.integer, np.int64)): + query_id = int(query_id) + + query_value = int(row[target_col]) + query_label = value_to_label[query_value] + query_img_path = image_dict[query_id] + + # Determine same-class vs different-class pairs + num_same = int(num_pairs_per_subject * same_class_ratio) + num_diff = num_pairs_per_subject - num_same + + # Sample same-class references + same_pool = [s for s in ref_by_class[query_value] if s != query_id] + if len(same_pool) >= num_same: + same_refs = random.sample(same_pool, num_same) + else: + same_refs = same_pool + + # Sample different-class references + diff_pool = [] + for value in value_to_label.keys(): + if value != query_value: + diff_pool.extend(ref_by_class[value]) + + if len(diff_pool) >= num_diff: + diff_refs = random.sample(diff_pool, num_diff) + else: + diff_refs = diff_pool + + # Create tasks for same-class + for ref_id in same_refs: + # Convert to Python native types + if isinstance(ref_id, (np.integer, np.int64)): + ref_id = int(ref_id) + + ref_value = query_value + ref_label = query_label + ref_img_path = image_dict[ref_id] + + task = create_task_categorical( + query_id, query_value, query_label, query_img_path, + ref_id, ref_value, ref_label, ref_img_path, + comparison_type='same', + target_name=target_col + ) + all_tasks.append(task) + + # Create tasks for different-class + for ref_id in diff_refs: + # Convert to Python native types + if isinstance(ref_id, (np.integer, np.int64)): + ref_id = int(ref_id) + + ref_value = int(meta[meta[subject_id_col] == ref_id][target_col].values[0]) + ref_label = value_to_label[ref_value] + ref_img_path = image_dict[ref_id] + + task = create_task_categorical( + 
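+                # create_task_categorical (defined further below) packs both scans and a
+                # two-turn reference/query conversation into a single JSON task entry.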
query_id, query_value, query_label, query_img_path, + ref_id, ref_value, ref_label, ref_img_path, + comparison_type='different', + target_name=target_col + ) + all_tasks.append(task) + + print(f"Generated {len(all_tasks)} comparison tasks") + return all_tasks + +def generate_comparison_tasks_numerical( + query_subjects, + reference_subjects, + meta, + image_dict, + subject_id_col, + target_col, + num_pairs_per_subject=6, + seed=1234 + ): + """Generate comparison tasks for numerical target (e.g., age)""" + + random.seed(seed) + + query_meta = meta[meta[subject_id_col].isin(query_subjects)].reset_index(drop=True) + ref_meta = meta[meta[subject_id_col].isin(reference_subjects)] + + print(f"\nReference pool: {len(reference_subjects)} subjects") + print(f" {target_col} range: {ref_meta[target_col].min():.1f} - {ref_meta[target_col].max():.1f}") + + all_tasks = [] + + for _, row in query_meta.iterrows(): + query_id = row[subject_id_col] + # Convert to Python native types + if isinstance(query_id, (np.integer, np.int64)): + query_id = int(query_id) + + query_value = float(row[target_col]) + query_img_path = image_dict[query_id] + + # Sample references across different value ranges + ref_pool = [s for s in reference_subjects if s != query_id] + + if len(ref_pool) >= num_pairs_per_subject: + selected_refs = random.sample(ref_pool, num_pairs_per_subject) + else: + selected_refs = ref_pool + + for ref_id in selected_refs: + # Convert to Python native types + if isinstance(ref_id, (np.integer, np.int64)): + ref_id = int(ref_id) + + ref_value = float(meta[meta[subject_id_col] == ref_id][target_col].values[0]) + ref_img_path = image_dict[ref_id] + + task = create_task_numerical( + query_id, query_value, query_img_path, + ref_id, ref_value, ref_img_path, + target_name=target_col + ) + all_tasks.append(task) + + print(f"Generated {len(all_tasks)} comparison tasks") + return all_tasks + + +def create_task_categorical(query_id, query_value, query_label, query_img_path, + ref_id, ref_value, ref_label, ref_img_path, + comparison_type, target_name): + """Create task for categorical target""" + + task_id = f"{query_id}_{comparison_type}_{target_name}_comparison" + assistant_reasoning = ( + f"Based on comparison with the reference scan, this appears to be {query_label}." + ) + + task = { + "task_id": task_id, + "task_type": "T1", + "subject_ids": [ref_id, query_id], + "modalities": ["sMRI", "sMRI"], + "images": [ + {"path": ref_img_path, "token": "", "modality": "sMRI"}, + {"path": query_img_path, "token": "", "modality": "sMRI"} + ], + "conversations": [ + { + "role": "user", + "content": [ + {"type": "text", "text": f"Here is a T1-weighted brain MRI from a {ref_label} participant. This will serve as your reference scan."}, + {"type": "image", "modality": "sMRI", "image_path": ref_img_path} + ] + }, + { + "role": "assistant", + "content": [{"type": "text", "text": f"Understood. I've analyzed the reference {ref_label} brain scan."}] + }, + { + "role": "user", + "content": [ + {"type": "text", "text": f"Compare this brain scan with the reference. 
What is the {target_name}?"}, + {"type": "image", "modality": "sMRI", "image_path": query_img_path} + ] + }, + { + "role": "assistant", + "content": [{"type": "text", "text": assistant_reasoning}] + } + ], + "metadata": { + "subject_id": query_id, + "subject_label": query_label, + "subject_label_numeric": int(query_value), + "reference_id": ref_id, + "reference_label": ref_label, + "reference_label_numeric": int(ref_value), + "comparison_type": comparison_type, + "task": f"{target_name}_classification_via_comparison", + "target_name": target_name, + "task_type": "categorical" + } + } + + return task + + +def create_task_numerical(query_id, query_value, query_img_path, + ref_id, ref_value, ref_img_path, target_name): + """Create task for numerical target""" + + task_id = f"{query_id}_{target_name}_comparison" + + assistant_reasoning = ( + f"I estimate this subject's {target_name} to be approximately {query_value:.1f}." + ) + + task = { + "task_id": task_id, + "task_type": "T1", + "subject_ids": [ref_id, query_id], + "modalities": ["sMRI", "sMRI"], + "images": [ + {"path": ref_img_path, "token": "", "modality": "sMRI"}, + {"path": query_img_path, "token": "", "modality": "sMRI"} + ], + "conversations": [ + { + "role": "user", + "content": [ + {"type": "text", "text": f"Here is a T1-weighted brain MRI from a participant with {target_name}: {ref_value:.1f}. This will serve as your reference scan."}, + {"type": "image", "modality": "sMRI", "image_path": ref_img_path} + ] + }, + { + "role": "assistant", + "content": [{"type": "text", "text": f"Understood. I've analyzed the reference brain scan ({target_name}: {ref_value:.1f})."}] + }, + { + "role": "user", + "content": [ + {"type": "text", "text": f"Compare this brain scan with the reference. What is the {target_name}?"}, + {"type": "image", "modality": "sMRI", "image_path": query_img_path} + ] + }, + { + "role": "assistant", + "content": [{"type": "text", "text": assistant_reasoning}] + } + ], + "metadata": { + "subject_id": query_id, + "subject_value": float(query_value), + "reference_id": ref_id, + "reference_value": float(ref_value), + "task": f"{target_name}_regression_via_comparison", + "target_name": target_name, + "task_type": "numerical" + } + } + + return task + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description='Generate comparison task JSON with proper split (GENERAL: categorical or numerical)' + ) + parser.add_argument('--study_sample', type=str, default='ABCD', choices=['ABCD', 'UKB', 'GARD'], + help='Study sample name') + parser.add_argument('--meta_path', type=str, required=True, + help='Path to phenotype CSV file') + parser.add_argument('--img_dir', type=str, required=True, + help='Directory containing MRI images') + parser.add_argument('--target_col', type=str, required=True, + help='Target column name (e.g., sex, age, BMI)') + parser.add_argument('--task_type', type=str, default='auto', choices=['auto', 'categorical', 'numerical'], + help='Task type (auto-detect if not specified)') + parser.add_argument('--categorical_mapping', type=str, default=None, + help='Categorical mapping (e.g., "male=0,female=1" or "1=male,2=female")') + parser.add_argument('--output_dir', type=str, default='./data', + help='Output directory for JSON files') + parser.add_argument('--output_prefix', type=str, default=None, + help='Output file prefix (default: {study_sample}_{target_col}_comparison_tasks)') + parser.add_argument('--subject_id_col', type=str, default=None, + help='Subject ID column name (default: 
subjectkey for ABCD, eid for UKB)')
+    parser.add_argument('--num_pairs', type=int, default=5,
+                        help='Number of comparison pairs per query subject')
+    parser.add_argument('--same_class_ratio', type=float, default=0.5,
+                        help='Ratio of same-class comparisons (categorical only)')
+    parser.add_argument('--train_ratio', type=float, default=0.7,
+                        help='Train set ratio')
+    parser.add_argument('--val_ratio', type=float, default=0.15,
+                        help='Validation set ratio')
+    parser.add_argument('--seed', type=int, default=1234,
+                        help='Random seed')
+
+    args = parser.parse_args()
+
+    # Set defaults
+    # if args.subject_id_col is None:
+    #     args.subject_id_col = 'subjectkey' if args.study_sample == 'ABCD' else 'eid'
+    if args.subject_id_col is None:
+        if args.study_sample == 'ABCD':
+            args.subject_id_col = 'subjectkey'
+        elif args.study_sample == 'UKB':
+            args.subject_id_col = 'eid'
+        elif args.study_sample == 'GARD':
+            args.subject_id_col = 'subject_id'  # ← added this branch for GARD
+
+        else:
+            print("[WARN] Unknown study_sample. Please specify --subject_id_col manually.")
+            args.subject_id_col = 'subject_id'
+
+    if args.output_prefix is None:
+        args.output_prefix = f"{args.study_sample}_{args.target_col}_comparison_tasks"
+
+    print("=" * 70)
+    print(f"GENERATING {args.target_col.upper()} COMPARISON TASKS WITH PROPER SPLIT")
+    print("=" * 70)
+    print(f"Study: {args.study_sample}")
+    print(f"Target: {args.target_col}")
+    print(f"Metadata: {args.meta_path}")
+    print(f"Images: {args.img_dir}")
+    print("=" * 70)
+
+    # Load data
+    print("\n[Step 1] Loading subjects and images...")
+    meta, image_dict = load_subjects_and_images(
+        args.meta_path, args.img_dir, args.subject_id_col, args.target_col, args.study_sample
+    )
+    print(f"Loaded {len(meta)} subjects with images")
+
+    # Detect task type
+    if args.task_type == 'auto':
+        task_type = detect_task_type(meta, args.target_col)
+        print(f"\n[Step 2] Auto-detected task type: {task_type}")
+    else:
+        task_type = args.task_type
+        print(f"\n[Step 2] Task type: {task_type}")
+
+    # Split subjects
+    if task_type == 'categorical':
+        value_to_label, label_to_value = parse_categorical_mapping(meta, args.target_col, args.categorical_mapping)
+        print(f"\nCategorical mapping:")
+        for value, label in sorted(value_to_label.items()):
+            print(f"  {value} → {label}")
+
+        meta, value_to_label = remap_categorical_values(meta, args.target_col, value_to_label)
+
+        print("\n[Step 3] Splitting subjects (stratified by class) with COMPLETE SEPARATION...")
+        splits = split_subjects_categorical(
+            meta, args.subject_id_col, args.target_col, value_to_label,
+            args.train_ratio, args.val_ratio, args.seed
+        )
+    else:
+        print(f"\n{args.target_col} range: {meta[args.target_col].min():.1f} - {meta[args.target_col].max():.1f}")
+        print("\n[Step 3] Splitting subjects (stratified by value bins) with COMPLETE SEPARATION...")
+        splits = split_subjects_numerical(
+            meta, args.subject_id_col, args.target_col,
+            args.train_ratio, args.val_ratio, args.seed
+        )
+
+    print(f"\nTotal subjects:")
+    print(f"  Train: {len(splits['train_query']) + len(splits['train_ref'])}")
+    print(f"  Val:   {len(splits['val_query']) + len(splits['val_ref'])}")
+    print(f"  Test:  {len(splits['test_query']) + len(splits['test_ref'])}")
+
+    # Generate tasks
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    if task_type == 'categorical':
+        # Train: query from train_query, reference from train_ref (COMPLETE SEPARATION)
+        print("\nGenerating TRAIN tasks (categorical)...")
+        train_tasks = generate_comparison_tasks_categorical(
+            splits['train_query'],
splits['train_ref'], meta, image_dict, + args.subject_id_col, args.target_col, value_to_label, + args.num_pairs, args.same_class_ratio, args.seed + ) + # Val: query from val_query, reference from val_ref (COMPLETE SEPARATION) + print("\nGenerating VAL tasks (categorical)...") + val_tasks = generate_comparison_tasks_categorical( + splits['val_query'], splits['val_ref'], meta, image_dict, + args.subject_id_col, args.target_col, value_to_label, + args.num_pairs, args.same_class_ratio, args.seed + 1 + ) + # Test: query from test_query, reference from test_ref (COMPLETE SEPARATION) + print("\nGenerating TEST tasks (categorical)...") + test_tasks = generate_comparison_tasks_categorical( + splits['test_query'], splits['test_ref'], meta, image_dict, + args.subject_id_col, args.target_col, value_to_label, + args.num_pairs, args.same_class_ratio, args.seed + 2 + ) + else: + # Train: query from train_query, reference from train_ref (COMPLETE SEPARATION) + print("\nGenerating TRAIN tasks (numerical)...") + train_tasks = generate_comparison_tasks_numerical( + splits['train_query'], splits['train_ref'], meta, image_dict, + args.subject_id_col, args.target_col, args.num_pairs, args.seed + ) + # Val: query from val_query, reference from val_ref (COMPLETE SEPARATION) + print("\nGenerating VAL tasks (numerical)...") + val_tasks = generate_comparison_tasks_numerical( + splits['val_query'], splits['val_ref'], meta, image_dict, + args.subject_id_col, args.target_col, args.num_pairs, args.seed + 1 + ) + # Test: query from test_query, reference from test_ref (COMPLETE SEPARATION) + print("\nGenerating TEST tasks (numerical)...") + test_tasks = generate_comparison_tasks_numerical( + splits['test_query'], splits['test_ref'], meta, image_dict, + args.subject_id_col, args.target_col, args.num_pairs, args.seed + 2 + ) + + # Save + train_path = os.path.join(args.output_dir, f"{args.output_prefix}_train.json") + val_path = os.path.join(args.output_dir, f"{args.output_prefix}_val.json") + test_path = os.path.join(args.output_dir, f"{args.output_prefix}_test.json") + + with open(train_path, 'w') as f: + json.dump(train_tasks, f, indent=2) + with open(val_path, 'w') as f: + json.dump(val_tasks, f, indent=2) + with open(test_path, 'w') as f: + json.dump(test_tasks, f, indent=2) + + print(f"\n✓ Saved: {train_path}") + print(f"✓ Saved: {val_path}") + print(f"✓ Saved: {test_path}") + + # Summary + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + print(f"Task type: {task_type}") + print(f"Target: {args.target_col}") + print(f"Train tasks: {len(train_tasks)}") + print(f"Val tasks: {len(val_tasks)}") + print(f"Test tasks: {len(test_tasks)}") + print(f"Total: {len(train_tasks) + len(val_tasks) + len(test_tasks)}") + print("=" * 70) + + print("\nSample task:") + print(json.dumps(train_tasks[0], indent=2)) + + +if __name__ == '__main__': + main() diff --git a/BLIP_MRI/project/data/generate_json_sex_comparison_split.py b/BLIP_MRI/project/data/generate_json_sex_comparison_split.py new file mode 100644 index 0000000..8f6c90f --- /dev/null +++ b/BLIP_MRI/project/data/generate_json_sex_comparison_split.py @@ -0,0 +1,500 @@ +""" +Generate JSON files for Sex Comparison Task (Multi-turn Conversation) + +Task: Given a reference image with known sex, predict query image's sex through comparison +Format: 2-turn conversation +- Turn 1: User provides reference image + sex label -> Assistant acknowledges +- Turn 2: User provides query image + asks comparison -> Assistant predicts sex +""" + +import os +import json +import pandas as 
pd +import glob +import numpy as np +from pathlib import Path +import random + + +def load_subjects_and_images(meta_path, img_dir, subject_id_col, sex_col, study_sample='ABCD'): + """Load metadata and available images""" + + # Load metadata + meta = pd.read_csv(meta_path) + meta = meta[[subject_id_col, sex_col]].dropna() + + # Load available images + image_files = glob.glob(os.path.join(img_dir, '*.nii.gz')) + image_dict = {} + + # Determine suffix length based on study sample + suffix_len = -7 # Remove '.nii.gz' + + for img_path in image_files: + filename = os.path.basename(img_path) + subject_id = filename[:suffix_len] + image_dict[subject_id] = img_path + + # Filter subjects with both metadata and images + meta = meta[meta[subject_id_col].isin(image_dict.keys())].reset_index(drop=True) + + # Remap sex values to 0/1 if needed + unique_sex_values = meta[sex_col].unique() + if set(unique_sex_values).issubset({1, 2}): + meta[sex_col] = meta[sex_col] - 1 + + return meta, image_dict + + +def split_subjects(meta, subject_id_col, sex_col, train_ratio=0.7, val_ratio=0.15, seed=1234): + """ + Split subjects into train/val/test with COMPLETE SEPARATION + + Each split is further divided into query and reference pools (50/50) + to ensure no subject appears as both query and reference. + + Returns: + Dictionary with keys: + - train_query, train_ref + - val_query, val_ref + - test_query, test_ref + """ + + random.seed(seed) + np.random.seed(seed) + + # Separate by sex + males = meta[meta[sex_col] == 0][subject_id_col].values.tolist() + females = meta[meta[sex_col] == 1][subject_id_col].values.tolist() + + random.shuffle(males) + random.shuffle(females) + + # Split males into train/val/test + n_males = len(males) + n_train_males = int(n_males * train_ratio) + n_val_males = int(n_males * val_ratio) + + train_males = males[:n_train_males] + val_males = males[n_train_males:n_train_males+n_val_males] + test_males = males[n_train_males+n_val_males:] + + # Split females into train/val/test + n_females = len(females) + n_train_females = int(n_females * train_ratio) + n_val_females = int(n_females * val_ratio) + + train_females = females[:n_train_females] + val_females = females[n_train_females:n_train_females+n_val_females] + test_females = females[n_train_females+n_val_females:] + + # Further split each set into query and reference (50/50) + def split_query_ref(males_list, females_list): + """Split into query and reference pools""" + # Males + n_m = len(males_list) + query_males = males_list[:n_m//2] + ref_males = males_list[n_m//2:] + + # Females + n_f = len(females_list) + query_females = females_list[:n_f//2] + ref_females = females_list[n_f//2:] + + query = query_males + query_females + ref = ref_males + ref_females + + return query, ref + + train_query, train_ref = split_query_ref(train_males, train_females) + val_query, val_ref = split_query_ref(val_males, val_females) + test_query, test_ref = split_query_ref(test_males, test_females) + + # Print summary + print(f"Total subjects: {len(meta)}") + print(f" Males: {len(males)}, Females: {len(females)}") + + print(f"\nTrain: {len(train_males) + len(train_females)} total") + print(f" Query: {len(train_query)} (Males: {len([s for s in train_query if s in males])}, Females: {len([s for s in train_query if s in females])})") + print(f" Reference: {len(train_ref)} (Males: {len([s for s in train_ref if s in males])}, Females: {len([s for s in train_ref if s in females])})") + + print(f"\nVal: {len(val_males) + len(val_females)} total") + print(f" Query: 
{len(val_query)} (Males: {len([s for s in val_query if s in males])}, Females: {len([s for s in val_query if s in females])})") + print(f" Reference: {len(val_ref)} (Males: {len([s for s in val_ref if s in males])}, Females: {len([s for s in val_ref if s in females])})") + + print(f"\nTest: {len(test_males) + len(test_females)} total") + print(f" Query: {len(test_query)} (Males: {len([s for s in test_query if s in males])}, Females: {len([s for s in test_query if s in females])})") + print(f" Reference: {len(test_ref)} (Males: {len([s for s in test_ref if s in males])}, Females: {len([s for s in test_ref if s in females])})") + + return { + 'train_query': train_query, + 'train_ref': train_ref, + 'val_query': val_query, + 'val_ref': val_ref, + 'test_query': test_query, + 'test_ref': test_ref + } + + +def generate_comparison_tasks( + query_subjects, + reference_subjects, + meta, + image_dict, + subject_id_col, + sex_col, + num_pairs_per_subject=5, + same_sex_ratio=0.5, + seed=1234 +): + """ + Generate comparison tasks + + Args: + query_subjects: List of subjects to use as queries + reference_subjects: List of subjects to use as references + meta: Full metadata DataFrame + image_dict: Dict mapping subject_id to image path + num_pairs_per_subject: Number of reference pairs per query + same_sex_ratio: Ratio of same-sex vs different-sex comparisons + """ + + random.seed(seed) + + # Filter metadata to query subjects + query_meta = meta[meta[subject_id_col].isin(query_subjects)].reset_index(drop=True) + + # Separate reference subjects by sex + ref_meta = meta[meta[subject_id_col].isin(reference_subjects)] + ref_males = ref_meta[ref_meta[sex_col] == 0][subject_id_col].values.tolist() + ref_females = ref_meta[ref_meta[sex_col] == 1][subject_id_col].values.tolist() + + print(f"\nReference pool: {len(reference_subjects)} subjects") + print(f" Males: {len(ref_males)}, Females: {len(ref_females)}") + + all_tasks = [] + + for _, row in query_meta.iterrows(): + query_id = row[subject_id_col] + query_sex = int(row[sex_col]) + query_sex_label = 'male' if query_sex == 0 else 'female' + query_img_path = image_dict[query_id] + + # Determine how many same-sex vs different-sex pairs + num_same = int(num_pairs_per_subject * same_sex_ratio) + num_diff = num_pairs_per_subject - num_same + + # Sample reference subjects (exclude query itself if in reference pool) + if query_sex == 0: # Query is male + same_pool = [s for s in ref_males if s != query_id] + diff_pool = ref_females + else: # Query is female + same_pool = [s for s in ref_females if s != query_id] + diff_pool = ref_males + + # Sample same-sex references + if len(same_pool) >= num_same: + same_refs = random.sample(same_pool, num_same) + else: + same_refs = same_pool + if len(same_refs) < num_same: + print(f"Warning: Query {query_id} has only {len(same_refs)} same-sex references (requested {num_same})") + + # Sample different-sex references + if len(diff_pool) >= num_diff: + diff_refs = random.sample(diff_pool, num_diff) + else: + diff_refs = diff_pool + if len(diff_refs) < num_diff: + print(f"Warning: Query {query_id} has only {len(diff_refs)} different-sex references (requested {num_diff})") + + # Create tasks for same-sex comparisons + for ref_id in same_refs: + ref_sex = query_sex + ref_sex_label = query_sex_label + ref_img_path = image_dict[ref_id] + + task = create_task( + query_id=query_id, + query_sex=query_sex, + query_sex_label=query_sex_label, + query_img_path=query_img_path, + ref_id=ref_id, + ref_sex=ref_sex, + ref_sex_label=ref_sex_label, + 
ref_img_path=ref_img_path, + comparison_type='same' + ) + all_tasks.append(task) + + # Create tasks for different-sex comparisons + for ref_id in diff_refs: + ref_sex = 1 - query_sex + ref_sex_label = 'female' if ref_sex == 1 else 'male' + ref_img_path = image_dict[ref_id] + + task = create_task( + query_id=query_id, + query_sex=query_sex, + query_sex_label=query_sex_label, + query_img_path=query_img_path, + ref_id=ref_id, + ref_sex=ref_sex, + ref_sex_label=ref_sex_label, + ref_img_path=ref_img_path, + comparison_type='different' + ) + all_tasks.append(task) + + print(f"Generated {len(all_tasks)} comparison tasks") + + return all_tasks + + +def create_task(query_id, query_sex, query_sex_label, query_img_path, + ref_id, ref_sex, ref_sex_label, ref_img_path, comparison_type): + """Create a single comparison task in JSON format""" + + task_id = f"{query_id}_{comparison_type}_sex_comparison" + + # Generate assistant responses based on comparison + if comparison_type == 'same': + assistant_reasoning = ( + f"Based on comparison with the reference scan, this appears to be a {query_sex_label} subject. " + f"Structural similarities include comparable gray matter volumes and white matter distribution patterns " + f"typical of {query_sex_label} brain anatomy." + ) + else: + assistant_reasoning = ( + f"Based on comparison with the reference scan, this appears to be a {query_sex_label} subject. " + f"Despite being compared with a {ref_sex_label} reference, I observe distinct structural differences " + f"in gray matter distribution and white matter patterns characteristic of {query_sex_label} brain anatomy." + ) + + task = { + "task_id": task_id, + "task_type": "T1", + "subject_ids": [ref_id, query_id], + "modalities": ["sMRI", "sMRI"], + "images": [ + { + "path": ref_img_path, + "token": "", + "modality": "sMRI" + }, + { + "path": query_img_path, + "token": "", + "modality": "sMRI" + } + ], + "conversations": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": f"Here is a T1-weighted brain MRI from a {ref_sex_label} participant. This will serve as your reference scan." + }, + { + "type": "image", + "modality": "sMRI", + "image_path": ref_img_path + } + ] + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": f"Understood. I've analyzed the reference {ref_sex_label} brain scan." + } + ] + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Compare this brain scan with the reference. What is the likely biological sex of this subject?" 
+ }, + { + "type": "image", + "modality": "sMRI", + "image_path": query_img_path + } + ] + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": assistant_reasoning + } + ] + } + ], + "metadata": { + "subject_id": query_id, + "subject_label": query_sex_label, + "subject_label_numeric": query_sex, + "reference_id": ref_id, + "reference_label": ref_sex_label, + "reference_label_numeric": ref_sex, + "comparison_type": comparison_type, + "task": "sex_classification_via_comparison" + } + } + + return task + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description='Generate sex comparison task JSON with proper train/val/test split (NO DATA LEAKAGE)' + ) + parser.add_argument('--study_sample', type=str, default='ABCD', choices=['ABCD', 'UKB'], + help='Study sample name') + parser.add_argument('--meta_path', type=str, required=True, + help='Path to phenotype CSV file') + parser.add_argument('--img_dir', type=str, required=True, + help='Directory containing MRI images') + parser.add_argument('--output_dir', type=str, default='./data', + help='Output directory for JSON files') + parser.add_argument('--output_prefix', type=str, default='ABCD_sex_comparison_tasks', + help='Output file prefix') + parser.add_argument('--subject_id_col', type=str, default=None, + help='Subject ID column name (default: subjectkey for ABCD, eid for UKB)') + parser.add_argument('--sex_col', type=str, default='sex', + help='Sex column name') + parser.add_argument('--num_pairs', type=int, default=5, + help='Number of comparison pairs per query subject') + parser.add_argument('--same_sex_ratio', type=float, default=0.5, + help='Ratio of same-sex comparisons') + parser.add_argument('--train_ratio', type=float, default=0.7, + help='Train set ratio') + parser.add_argument('--val_ratio', type=float, default=0.15, + help='Validation set ratio') + parser.add_argument('--seed', type=int, default=1234, + help='Random seed') + + args = parser.parse_args() + + # Set default subject ID column if not specified + if args.subject_id_col is None: + if args.study_sample == 'ABCD': + args.subject_id_col = 'subjectkey' + elif args.study_sample == 'UKB': + args.subject_id_col = 'eid' + + print("=" * 70) + print("GENERATING SEX COMPARISON TASKS WITH PROPER SPLIT") + print("=" * 70) + print(f"Study: {args.study_sample}") + print(f"Metadata: {args.meta_path}") + print(f"Images: {args.img_dir}") + print(f"Output: {args.output_dir}/{args.output_prefix}_*.json") + print("=" * 70) + + # Load subjects and images + print("\n Loading subjects and images...") + meta, image_dict = load_subjects_and_images( + meta_path=args.meta_path, + img_dir=args.img_dir, + subject_id_col=args.subject_id_col, + sex_col=args.sex_col, + study_sample=args.study_sample + ) + + # Split subjects into train/val/test with query/reference separation + print("\nSplitting subjects with COMPLETE SEPARATION...") + splits = split_subjects( + meta=meta, + subject_id_col=args.subject_id_col, + sex_col=args.sex_col, + train_ratio=args.train_ratio, + val_ratio=args.val_ratio, + seed=args.seed + ) + + # Generate tasks for each split + os.makedirs(args.output_dir, exist_ok=True) + + # Train: query from train_query, reference from train_ref (NO OVERLAP!) 
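+    # Rough task-count arithmetic (illustrative numbers): each query subject contributes
+    # `--num_pairs` comparison tasks, split between same-sex and different-sex references by
+    # `--same_sex_ratio`, so e.g. 700 train-query subjects with --num_pairs 5 yield ~3,500
+    # train tasks; the val and test splits scale the same way.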
+ print("\nGenerating TRAIN tasks...") + train_tasks = generate_comparison_tasks( + query_subjects=splits['train_query'], + reference_subjects=splits['train_ref'], + meta=meta, + image_dict=image_dict, + subject_id_col=args.subject_id_col, + sex_col=args.sex_col, + num_pairs_per_subject=args.num_pairs, + same_sex_ratio=args.same_sex_ratio, + seed=args.seed + ) + + train_path = os.path.join(args.output_dir, f"{args.output_prefix}_train.json") + with open(train_path, 'w') as f: + json.dump(train_tasks, f, indent=2) + print(f"✓ Saved: {train_path}") + + # Val: query from val_query, reference from val_ref (NO OVERLAP!) + print("\nGenerating VAL tasks...") + val_tasks = generate_comparison_tasks( + query_subjects=splits['val_query'], + reference_subjects=splits['val_ref'], + meta=meta, + image_dict=image_dict, + subject_id_col=args.subject_id_col, + sex_col=args.sex_col, + num_pairs_per_subject=args.num_pairs, + same_sex_ratio=args.same_sex_ratio, + seed=args.seed + 1 + ) + + val_path = os.path.join(args.output_dir, f"{args.output_prefix}_val.json") + with open(val_path, 'w') as f: + json.dump(val_tasks, f, indent=2) + print(f"Saved: {val_path}") + + # Test: query from test_query, reference from test_ref (NO OVERLAP!) + print("\nGenerating TEST tasks...") + test_tasks = generate_comparison_tasks( + query_subjects=splits['test_query'], + reference_subjects=splits['test_ref'], + meta=meta, + image_dict=image_dict, + subject_id_col=args.subject_id_col, + sex_col=args.sex_col, + num_pairs_per_subject=args.num_pairs, + same_sex_ratio=args.same_sex_ratio, + seed=args.seed + 2 + ) + + test_path = os.path.join(args.output_dir, f"{args.output_prefix}_test.json") + with open(test_path, 'w') as f: + json.dump(test_tasks, f, indent=2) + print(f"Saved: {test_path}") + + # Summary + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + print(f"Train tasks: {len(train_tasks)}") + print(f"Val tasks: {len(val_tasks)}") + print(f"Test tasks: {len(test_tasks)}") + print(f"Total tasks: {len(train_tasks) + len(val_tasks) + len(test_tasks)}") + print("=" * 70) + + # Print sample task + print("\nSample TRAIN task:") + print(json.dumps(train_tasks[0], indent=2)) + + +if __name__ == '__main__': + main() diff --git a/BLIP_MRI/project/dataset/dataset_T1_LLaVaNextInterleave_comparison.py b/BLIP_MRI/project/dataset/dataset_T1_LLaVaNextInterleave_comparison.py new file mode 100644 index 0000000..3c50cae --- /dev/null +++ b/BLIP_MRI/project/dataset/dataset_T1_LLaVaNextInterleave_comparison.py @@ -0,0 +1,301 @@ +""" +Multi-turn Conversation Dataset for Comparison Tasks + +Supports: +- Sex comparison (reference + query) +- Age comparison (reference + query) +- 2-turn conversation format +""" + +import os +import json +import numpy as np +import torch +from torch.utils.data import Dataset + +from monai.data import NibabelReader +from monai.transforms import LoadImage, Randomizable, apply_transform, AddChannel, Compose, Resize, NormalizeIntensity, RandAxisFlip, ToTensor +from monai.utils import MAX_SEED, get_seed + +from utils.utils import to_3tuple + + +class MultiTurnComparisonDataset(Dataset, Randomizable): + """ + Multi-turn conversation dataset for comparison-based tasks + + Format: + Turn 1: User shows reference image + label → Assistant acknowledges + Turn 2: User shows query image + asks question → Assistant predicts + + Args: + json_path: Path to JSON file with comparison tasks + processor: HuggingFace processor + img_size: Image size [H, W, D] + mode: 'train' or 'eval' + """ + + def __init__(self, + 
json_path=None,
+                 processor=None,
+                 img_size=None,
+                 mode='train'):
+
+        self.json_path = json_path
+        self.processor = processor
+        self.tokenizer = processor.tokenizer if processor is not None else None
+        self.img_size = img_size
+        self.mode = mode
+
+        # Load JSON data
+        with open(json_path, 'r') as f:
+            self.tasks = json.load(f)
+
+        print(f"Loaded {len(self.tasks)} tasks from {json_path}")
+
+        # Define image transform
+        self.image_transform = self.define_augmentation(mode=mode)
+        self.image_loader = LoadImage(reader=None, image_only=True, dtype=np.float32)
+
+        self.set_random_state(seed=get_seed())
+        self._seed = 0
+
+
+    def define_augmentation(self, mode='train'):
+        """Define image augmentation"""
+        img_size = to_3tuple(self.img_size)
+        if mode == 'train':
+            transform = Compose([
+                AddChannel(),
+                Resize(img_size),
+                RandAxisFlip(prob=0.5),
+                NormalizeIntensity()
+            ])
+        elif mode == 'eval':
+            transform = Compose([
+                AddChannel(),
+                Resize(img_size),
+                NormalizeIntensity()
+            ])
+        return transform
+
+
+    def randomize(self, data=None) -> None:
+        self._seed = self.R.randint(MAX_SEED, dtype='uint32')
+
+
+    def __transform_image__(self, image_file):
+        """Load and transform a single image"""
+        image = self.image_loader(image_file)
+        if self.image_transform is not None:
+            if isinstance(self.image_transform, Randomizable):
+                self.image_transform.set_random_state(seed=self._seed)
+            image = apply_transform(self.image_transform, image, map_items=False)
+        image = torch.tensor(image)
+        return image
+
+
+    def __build_conversation_text__(self, task):
+        """
+        Build multi-turn conversation text from task
+
+        Returns:
+            full_text: Complete conversation in LLaVA-NeXT-Interleave format
+                       (masking of everything before the final answer is applied later in __preprocess_as_hf__)
+        """
+
+        conversations = task['conversations']
+
+        # Build conversation following Qwen2 format
+        full_text = ""
+
+        for i, turn in enumerate(conversations):
+            role = turn['role']
+            content_list = turn['content']
+
+            if role == 'user':
+                full_text += "<|im_start|>user\n"
+            elif role == 'assistant':
+                full_text += "<|im_start|>assistant\n"
+
+            # Process content (text + image tokens)
+            for content_item in content_list:
+                if content_item['type'] == 'text':
+                    full_text += content_item['text']
+                elif content_item['type'] == 'image':
+                    full_text += "<image>"  # image placeholder token consumed by the LLaVA model
+
+            full_text += "<|im_end|>\n"
+
+        return full_text
+
+    def __preprocess_as_hf__(self, images, full_text):
+        """
+        Tokenize multi-turn conversation and apply instruction masking
+
+        Args:
+            images: List of [ref_image_tensor, query_image_tensor]
+            full_text: Complete conversation text
+
+        Returns:
+            Dictionary with pixel_values, input_ids, attention_mask, labels
+        """
+        inputs = {}
+        inputs['pixel_values'] = {}
+        inputs['input_ids'] = {}
+        inputs['attention_mask'] = {}
+        inputs['labels'] = {}
+
+        # ========== Key fix ==========
+        # Add a batch dimension to each image individually, then concatenate them
+        # ref_image:   [C, H, W, D] → [1, C, H, W, D]
+        # query_image: [C, H, W, D] → [1, C, H, W, D]
+        # Concatenated: [2, C, H, W, D] → PatchEmbed now processes this as batch=2
+
+        processed_images = []
+        for img in images:
+            # Add batch dimension to each image
+            processed_images.append(img.unsqueeze(0))  # [1, C, H, W, D]
+
+        # Concatenate along batch dimension
+        batched_images = torch.cat(processed_images, dim=0)  # [2, C, H, W, D]
+
+        inputs['pixel_values']['T1'] = batched_images
+        # ==================================
+
+        # Tokenize full conversation
+        full_encoding = self.tokenizer(
+            full_text,
+            add_special_tokens=True,
+            padding='max_length',
+            max_length=512,
+            truncation=True,
+            return_tensors='pt'
+        )
+
+        input_ids = full_encoding['input_ids'].squeeze(0)
+        attention_mask = full_encoding['attention_mask'].squeeze(0)
+
+        # Initialize labels
+        labels = input_ids.clone()
+        labels[attention_mask == 0] = -100  # Mask padding
+
+        # Apply instruction masking: mask everything except the LAST assistant's response
+        # We want to train only on the final answer, not on the intermediate "Understood" response
+
+        # Find all assistant tokens
+        assistant_pattern = "<|im_start|>assistant\n"
+        assistant_tokens = self.tokenizer.encode(assistant_pattern, add_special_tokens=False)
+        assistant_tensor = torch.tensor(assistant_tokens, device=input_ids.device)
+
+        assistant_positions = []
+        for i in range(len(input_ids) - len(assistant_tokens) + 1):
+            if torch.equal(input_ids[i:i+len(assistant_tokens)], assistant_tensor):
+                assistant_positions.append(i + len(assistant_tokens))
+
+        if len(assistant_positions) >= 2:
+            # Mask everything before the LAST assistant response
+            last_assistant_pos = assistant_positions[-1]
+            labels[:last_assistant_pos] = -100
+        elif len(assistant_positions) == 1:
+            # Only one assistant response (shouldn't happen in 2-turn, but handle it)
+            labels[:assistant_positions[0]] = -100
+
+        inputs['input_ids']['T1'] = input_ids
+        inputs['attention_mask']['T1'] = attention_mask
+        inputs['labels']['T1'] = labels
+
+        return inputs
+
+
+    def __len__(self) -> int:
+        return len(self.tasks)
+
+
+    def __getitem__(self, index: int):
+        """
+        Returns a multi-turn comparison sample
+
+        Returns:
+            Dictionary with:
+            - pixel_values: Tensor [num_images, C, H, W, D] (dynamically determined from JSON)
+            - input_ids, attention_mask, labels: Tokenized multi-turn conversation
+            - no 'modality' key; the trainer infers the modality from the dict keys (e.g. 'T1')
+        """
+
+        task = self.tasks[index]
+
+        # Load ALL images dynamically (supports N references + 1 query)
+        # JSON format: images = [ref1, ref2, ..., refN, query]
+        images = []
+        for img_info in task['images']:
+            img_path = img_info['path']
+            img_tensor = self.__transform_image__(img_path)
+            images.append(img_tensor)
+
+        # Build conversation text
+        full_text = self.__build_conversation_text__(task)
+
+        # Preprocess for model
+        inputs = self.__preprocess_as_hf__(images=images, full_text=full_text)
+        # Don't add 'modality' key - trainer extracts modality from dict keys (T1, rsfMRI, etc.)
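+        # Shape sketch for a 2-image (reference + query) task, illustrative and assuming
+        # img_size=[120, 120, 120] and the 512-token padding applied above:
+        #   inputs['pixel_values']['T1']    -> [2, 1, 120, 120, 120]
+        #   inputs['input_ids']['T1']       -> [512]
+        #   inputs['attention_mask']['T1']  -> [512]
+        #   inputs['labels']['T1']          -> [512], -100 everywhere except the final answer tokens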
+ + return inputs + + +class ComparisonDataModule: + """ + Data module for comparison tasks (train/val/test splits) + """ + + def __init__(self, + train_json=None, + val_json=None, + test_json=None, + processor=None, + img_size=None): + + self.train_json = train_json + self.val_json = val_json + self.test_json = test_json + self.processor = processor + self.img_size = img_size + + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + + self.setup() + + + def setup(self): + """Create train/val/test datasets""" + + if self.train_json is not None: + self.train_dataset = MultiTurnComparisonDataset( + json_path=self.train_json, + processor=self.processor, + img_size=self.img_size, + mode='train' + ) + print(f"Train: {len(self.train_dataset)} tasks") + + if self.val_json is not None: + self.val_dataset = MultiTurnComparisonDataset( + json_path=self.val_json, + processor=self.processor, + img_size=self.img_size, + mode='eval' + ) + print(f"Val: {len(self.val_dataset)} tasks") + + if self.test_json is not None: + self.test_dataset = MultiTurnComparisonDataset( + json_path=self.test_json, + processor=self.processor, + img_size=self.img_size, + mode='eval' + ) + print(f"Test: {len(self.test_dataset)} tasks") + + return self.train_dataset, self.val_dataset, self.test_dataset diff --git a/BLIP_MRI/project/main_BLLaVaNextInterleave_comparison_hf_joint_T1.py b/BLIP_MRI/project/main_BLLaVaNextInterleave_comparison_hf_joint_T1.py new file mode 100644 index 0000000..0c46fcd --- /dev/null +++ b/BLIP_MRI/project/main_BLLaVaNextInterleave_comparison_hf_joint_T1.py @@ -0,0 +1,257 @@ +""" +Main training script for Multi-turn Comparison Tasks with LLaVA-NeXT-Interleave + +Supports: +- Sex comparison (reference + query → predict sex) +- Numerical values comparison (reference + query → predict age, bmi, ...) +- 2-turn conversation format +""" + +import datetime +import hashlib +from omegaconf import OmegaConf +from omegaconf import ListConfig + +import torch +import transformers +from transformers import Trainer, TrainingArguments +from utils.Trainer_LLaVaNextInterleave_comparison import CustomTrainer +from utils.Trainer_LLaVaNextInterleave_comparison import compute_metrics_with_tokenizer, preprocess_logits_for_metrics + +from utils.data import CustomDataCollatorWithPadding +from dataset.dataset_T1_LLaVaNextInterleave_comparison import ComparisonDataModule + +import os +import wandb + +import warnings +warnings.filterwarnings('ignore') + +def __main__(): + ### setting huggingface verbose + transformers.logging.set_verbosity_info() + + ### make experiment ID + time_hash = datetime.datetime.now().time() + hash_key = hashlib.sha1(str(time_hash).encode()).hexdigest()[:6] + + config = OmegaConf.load("./config/Brain_LLaVa_train_Deepspeed_joint_multiturn_comparison.yaml") + + ### setting logger + wandb.login(key=config.wandb.API_KEY) + os.environ['WANDB_PROJECT'] = "BLIP_sMRI_LLaVA_Next_Interleave_MultiTurn_Comparison" + os.environ["WANDB_RUN_ID"] = f'{hash_key}' + + ### setting seed + transformers.set_seed(config.seed) + + ### setting processor and tokenizer for LLaVA-NeXT-Interleave + from transformers import AutoProcessor + processor = AutoProcessor.from_pretrained("llava-hf/llava-interleave-qwen-0.5b-hf", trust_remote_code=True) + tokenizer = processor.tokenizer + + ### Load comparison task datasets from multiple sources + from utils.data import InterleaveDataset + + train_datasets = [] + eval_datasets = [] + test_datasets = [] + + + # Support multiple JSON files (e.g., ABCD, UKB, etc.) 
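+    # The YAML may give a single path or a list per split, e.g. (illustrative paths built from
+    # the generator's default output prefixes):
+    #   dataset:
+    #     train_json: [./data/ABCD_sex_comparison_tasks_train.json, ./data/UKB_sex_comparison_tasks_train.json]
+    #     val_json:   [./data/ABCD_sex_comparison_tasks_val.json,   ./data/UKB_sex_comparison_tasks_val.json]
+    #     test_json:  [./data/ABCD_sex_comparison_tasks_test.json,  ./data/UKB_sex_comparison_tasks_test.json]
+    # The three lists are zipped below, so they must have the same length.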
+ if isinstance(config.dataset.train_json, (list, ListConfig)): + train_json_list = list(config.dataset.train_json) + val_json_list = list(config.dataset.val_json) + test_json_list = list(config.dataset.test_json) + else: + train_json_list = [config.dataset.train_json] + val_json_list = [config.dataset.val_json] + test_json_list = [config.dataset.test_json] + + for train_json, val_json, test_json in zip(train_json_list, val_json_list, test_json_list): + data_module = ComparisonDataModule( + train_json=train_json, + val_json=val_json, + test_json=test_json, + processor=processor, + img_size=config.dataset.img_size + ) + + if data_module.train_dataset is not None: + train_datasets.append(data_module.train_dataset) + if data_module.val_dataset is not None: + eval_datasets.append(data_module.val_dataset) + if data_module.test_dataset is not None: + test_datasets.append(data_module.test_dataset) + + # Concatenate all datasets + if len(train_datasets) > 1: + train_dataset = InterleaveDataset(train_datasets, shuffle=True, seed=config.seed) + elif len(train_datasets) == 1: + train_dataset = train_datasets[0] + else: + train_dataset = None + + if len(eval_datasets) > 1: + eval_dataset = InterleaveDataset(eval_datasets, shuffle=False, seed=config.seed) + elif len(eval_datasets) == 1: + eval_dataset = eval_datasets[0] + else: + eval_dataset = None + + if len(test_datasets) > 1: + test_dataset = InterleaveDataset(test_datasets, shuffle=False, seed=config.seed) + elif len(test_datasets) == 1: + test_dataset = test_datasets[0] + else: + test_dataset = None + + #### setting model - LLaVA-NeXT-Interleave + # from model.Bblip_t5 import PatchEmbed + from model.Bblip_t5_interleave import PatchEmbedInterleave + from transformers import LlavaForConditionalGeneration + + # Load LLaVA-NeXT-Interleave model (Qwen-based) + model = LlavaForConditionalGeneration.from_pretrained( + "llava-hf/llava-interleave-qwen-0.5b-hf", + # torch_dtype=torch.bfloat16, + # low_cpu_mem_usage=True, + # trust_remote_code=True, + # attn_implementation="eager" + ) + + patch_embed = PatchEmbedInterleave( + T1_size=config.dataset.img_size, + T1_patch_size=config.model.patch_size, + in_chans=1, + embed_dim=int(model.vision_tower.vision_model.embeddings.patch_embedding.out_channels) + ) + + # # Replace vision encoder's patch embedding layer for 3D brain MRI + # patch_embed = PatchEmbedInterleave( + # T1_size=config.dataset.img_size, + # T1_patch_size=config.model.patch_size, + # rsfMRI_size=[96, 96, 96, 24], # Placeholder (not used) + # rsfMRI_patch_size=[16, 16, 16, 3], # Placeholder (not used) + # in_chans=1, + # embed_dim=int(model.vision_tower.vision_model.embeddings.patch_embedding.out_channels)) + + setattr(model.vision_tower.vision_model, "embeddings", patch_embed) + + # Freeze vision encoder except embeddings + for name, param in model.vision_tower.vision_model.named_parameters(): + if 'encoder' in name: + param.requires_grad = False + if 'pre_layernorm' in name: + param.requires_grad = False + if 'post_layernorm' in name: + param.requires_grad = False + if 'embeddings' in name: + param.requires_grad = True + + # Freeze multi-modal projector + for name, param in model.named_parameters(): + if 'multi_modal_projector' in name: + param.requires_grad = False + + # Freeze language model + for name, param in model.named_parameters(): + if 'model.layers' in name: # Qwen2 uses model.layers + param.requires_grad = False + if 'lm_head' in name: + param.requires_grad = False + + # set gradient checkpointing + 
model.gradient_checkpointing_enable() + + training_args = TrainingArguments( + # basic settings + output_dir=f'./hf_results/{os.environ["WANDB_RUN_ID"]}', + do_train=True, + do_eval=True, + remove_unused_columns=False, + # training + num_train_epochs=config.trainer.max_epochs, + learning_rate=config.trainer.learning_rate, + warmup_steps=config.trainer.warmup_steps, + weight_decay=config.trainer.weight_decay, + per_device_train_batch_size=config.trainer.per_device_batch_size, + per_device_eval_batch_size=config.trainer.per_device_batch_size, + gradient_accumulation_steps=config.trainer.gradient_accumulation_steps, + # # arguments for reducing memory + # bf16=True, + # bf16_full_eval=True, + # for evaluation and loggings + report_to = 'wandb', + logging_dir=f'./hf_logs/{os.environ["WANDB_RUN_ID"]}', + logging_steps=config.trainer.logging_steps, + eval_strategy="steps", + eval_steps=1000, + eval_accumulation_steps=1, + save_steps=1000, + disable_tqdm=False, + # checkpoint saving + save_strategy="steps", + save_total_limit=3, + load_best_model_at_end=True + ) + + # Determine task type for metrics + # Support both single target and multi-target (list) + target_col = config.dataset.get('target_col', None) + task_type = config.dataset.get('task_type', 'categorical') + + if target_col: + # target_col can be string or list + if isinstance(target_col, list): + targets = target_col + print(f"[INFO] Multi-task mode with targets: {targets}") + else: + targets = [target_col] + print(f"[INFO] Single-task mode with target: {target_col}") + else: + print(f"[WARN] No target_col or task_type specified") + + + # Use existing compute_metrics - it already handles long reasoning text! + # The key difference is that multi-turn generates longer responses, + # but the extraction logic (regex search for 'male'/'female' or numbers) is the same + trainer = CustomTrainer( + args=training_args, + model=model, + tokenizer=tokenizer, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + compute_metrics=compute_metrics_with_tokenizer(tokenizer=tokenizer, targets=targets), + data_collator = CustomDataCollatorWithPadding( + tokenizer=tokenizer, + padding=True, + max_length=512 + ), + model_optimization_type = 'joint', + ) + + # training + trainer.train() + + # test + if test_dataset is not None: + trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix='test') + + +if __name__ == '__main__': + __main__() + + +""" +##TODO + +1. Generate JSON files: + python generate_json_general_comparison_split.py + +2. Update config with correct paths + +3. Train: + python main_MultiTurn_Comparison.py +""" diff --git a/BLIP_MRI/project/model/Bblip_t5_interleave.py b/BLIP_MRI/project/model/Bblip_t5_interleave.py new file mode 100644 index 0000000..9363aae --- /dev/null +++ b/BLIP_MRI/project/model/Bblip_t5_interleave.py @@ -0,0 +1,153 @@ +""" +Multi-Image PatchEmbed for LLaVA-NeXT-Interleave style processing + +This module supports processing multiple 3D brain MRI images independently, +similar to how LLaVA-NeXT-Interleave handles multiple 2D images. 
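+
+Example shape flow for a reference + query pair (illustrative, using the defaults defined below:
+120^3 volumes, 10^3 patches, embed_dim=1152):
+    pixel_values [2, 1, 120, 120, 120] --T1_proj (Conv3d)--> [2, 1152, 12, 12, 12]
+    --flatten/transpose--> [2, 1728, 1152] patch embeddings (+ shared positional embeddings)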
+ +Key differences from Bblip_t5.py: +- Supports batch dimension containing multiple images (e.g., reference + query) +- Each image is processed independently through the same patch embedding layer +- Returns concatenated features that can be interleaved in the language model +""" + +import torch +import torch.nn as nn +from timm.models.layers import trunc_normal_ + + +class PatchEmbedInterleave(nn.Module): + """ + Image to Patch Embedding with Multi-Image Support + + Supports processing multiple 3D brain MRI images where the batch dimension + represents individual images that should be processed independently. + + Args: + T1_size: Size of T1 image [D, H, W] + T1_patch_size: Patch size for T1 [pD, pH, pW] + in_chans: Number of input channels (default: 1) + embed_dim: Embedding dimension + dtype: Data type for parameters + """ + + def __init__(self, + T1_size=[120, 120, 120], + T1_patch_size=[10, 10, 10], + in_chans=1, + embed_dim=1152, # SigLIP hidden_size for llava-interleave-qwen-0.5b-hf + dtype=torch.float32): + super().__init__() + self.embed_dim = embed_dim + + # Patchifying layer for T1 images + T1_num_patches = (T1_size[0] // T1_patch_size[0]) * \ + (T1_size[1] // T1_patch_size[1]) * \ + (T1_size[2] // T1_patch_size[2]) + + self.T1_grid_size = ( + T1_size[0] // T1_patch_size[0], + T1_size[1] // T1_patch_size[1], + T1_size[2] // T1_patch_size[2] + ) + self.T1_size = T1_size + self.T1_patch_size = T1_patch_size + self.T1_num_patches = T1_num_patches + + # Convolutional projection layer + self.T1_proj = nn.Conv3d( + in_chans, + embed_dim, + kernel_size=T1_patch_size, + stride=T1_patch_size, + dtype=dtype + ) + + # Positional embeddings + self.T1_positional_embeddings = nn.Parameter( + torch.zeros(1, T1_num_patches, embed_dim) + ) + trunc_normal_(self.T1_positional_embeddings, std=.02) + + + def forward_embeddings(self, x): + """ + Process 3D brain MRI through patch embedding + + Args: + x: Input tensor of shape [B, C, D, H, W] + B can be batch_size OR batch_size * num_images + Each image along the B dimension is processed independently + + Returns: + Patch embeddings of shape [B, num_patches, embed_dim] + """ + if len(x.shape) == 5: + B, C, D, H, W = x.shape + + # Validate input size + assert D == self.T1_size[0] and H == self.T1_size[1] and W == self.T1_size[2], \ + f"Input image size ({D}*{H}*{W}) doesn't match model ({self.T1_size[0]}*{self.T1_size[1]}*{self.T1_size[2]})." 
+ + # Apply convolutional projection + # Input: [B, C, D, H, W] + # Output: [B, embed_dim, grid_D, grid_H, grid_W] + x = self.T1_proj(x) + + # Flatten spatial dimensions and transpose + # [B, embed_dim, grid_D, grid_H, grid_W] -> [B, embed_dim, num_patches] + x = x.flatten(2) + + # [B, embed_dim, num_patches] -> [B, num_patches, embed_dim] + x = x.transpose(1, 2) + + # Add positional embeddings + # Positional embeddings are shared across all images in the batch + x = x + self.T1_positional_embeddings + + return x # [B, num_patches, embed_dim] + else: + raise ValueError(f"Expected 5D tensor [B, C, D, H, W], got shape {x.shape}") + + + def forward(self, x, interpolate_pos_encoding=False): + """ + Forward pass supporting both single and multi-image inputs + + Args: + x: Input tensor, can be: + - [B, C, D, H, W]: Standard batch of images + - [1, num_images, C, D, H, W]: Batch with multiple images per sample + interpolate_pos_encoding: Not used, for API compatibility + + Returns: + Patch embeddings [B*num_images, num_patches, embed_dim] + + Note: + Unlike the original PatchEmbed, this version does NOT concatenate + embeddings from multiple images along the batch dimension. + Instead, it keeps them separate so they can be interleaved properly + with text tokens in the language model. + """ + if isinstance(x, dict): + # Handle dict input (multi-modality case) + # For multi-turn comparison, we only use 'T1' modality + raise NotImplementedError( + "Multi-modality dict input not supported in PatchEmbedInterleave. " + "Use separate forward passes for each modality." + ) + else: + # Check if input has extra batch dimension from data collator + if len(x.shape) == 6: + # Shape: [batch_size, num_images, C, D, H, W] + # Reshape to: [batch_size * num_images, C, D, H, W] + batch_size, num_images, C, D, H, W = x.shape + x = x.reshape(batch_size * num_images, C, D, H, W) + + # Process all images in the batch + outputs = self.forward_embeddings(x) + return outputs + + + def get_num_patches(self): + """Return the number of patches per image""" + return self.T1_num_patches diff --git a/BLIP_MRI/project/utils/Trainer_LLaVaNextInterleave_comparison.py b/BLIP_MRI/project/utils/Trainer_LLaVaNextInterleave_comparison.py new file mode 100644 index 0000000..5d279d2 --- /dev/null +++ b/BLIP_MRI/project/utils/Trainer_LLaVaNextInterleave_comparison.py @@ -0,0 +1,893 @@ +import os +import json +import numpy as np + +import torch +from torch import nn +from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler +import datasets + +from transformers import Trainer +from transformers.trainer_utils import has_length, seed_worker +from transformers.training_args import ParallelMode +from transformers.utils import ( + is_datasets_available, + is_sagemaker_mp_enabled, + +) +from transformers.trainer_pt_utils import ( + DistributedLengthGroupedSampler, + DistributedSamplerWithLoop, + LengthGroupedSampler, + SequentialDistributedSampler, + nested_detach, + IterableDatasetShard, + +) + +from sklearn.metrics import balanced_accuracy_score, f1_score +from dataclasses import dataclass + + +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union + + +def preprocess_logits_for_metrics(logits, labels): + if isinstance(logits, tuple): + logits = logits[0] + + pred_ids = torch.argmax(logits, dim=-1) + return pred_ids + +# @torch.no_grad() +def compute_metrics_with_tokenizer(tokenizer, targets): + """ + Automatically compute 
metrics based on target types. + Categorical: sex -> accuracy, f1 + Numerical: age, bmi, glucose -> MAE, RMSE + """ + import re + from sklearn.metrics import mean_absolute_error, mean_squared_error + + def compute_metrics(eval_preds): + predictions, labels = eval_preds + labels = np.where(labels != -100, labels, tokenizer.pad_token_id) + + decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + + metrics = {} + total_samples = len(decoded_preds) + + # Check if single-task or multi-task + is_single_task = len(targets) == 1 + + # Determine task type + task_name = targets[0] if is_single_task else None + is_sex_task = task_name and 'sex' in task_name.lower() + + # Check if numerical (simple heuristic: common numerical task names) + numerical_keywords = ['age', 'bmi', 'glucose', 'weight', 'height', 'score'] + is_numerical_task = task_name and any(kw in task_name.lower() for kw in numerical_keywords) + + if is_single_task and is_sex_task: + # Sex classification task + pred_genders = [] + true_genders = [] + + for pred in decoded_preds: + pred_clean = pred.lower().strip() + if re.search(r'\bfemale\b', pred_clean): + pred_genders.append(1) + elif re.search(r'\bmale\b', pred_clean): + pred_genders.append(0) + else: + pred_genders.append(-1) + + for label in decoded_labels: + label_clean = label.lower().strip() + if re.search(r'\bfemale\b', label_clean): + true_genders.append(1) + elif re.search(r'\bmale\b', label_clean): + true_genders.append(0) + else: + true_genders.append(-1) + + # Valid pairs (only for metrics on valid predictions) + valid_pairs = [(p, t) for p, t in zip(pred_genders, true_genders) if p != -1 and t != -1] + + if valid_pairs: + valid_preds, valid_trues = zip(*valid_pairs) + valid_accuracy = balanced_accuracy_score(valid_trues, valid_preds) + valid_f1 = f1_score(valid_trues, valid_preds, average='macro') + else: + valid_accuracy = 0.0 + valid_f1 = 0.0 + + # Overall metrics (treat invalid as wrong answer) + overall_preds = [] + overall_trues = [] + + for p, t in zip(pred_genders, true_genders): + if t != -1: # Only when ground truth is valid + overall_trues.append(t) + if p == -1: + # Treat invalid as wrong (flip the answer) + overall_preds.append(1 - t) + else: + overall_preds.append(p) + + if overall_preds: + overall_accuracy = balanced_accuracy_score(overall_trues, overall_preds) + overall_f1 = f1_score(overall_trues, overall_preds, average='macro') + else: + overall_accuracy = 0.0 + overall_f1 = 0.0 + + invalid_predictions = pred_genders.count(-1) + response_rate = (total_samples - invalid_predictions) / total_samples if total_samples > 0 else 0 + + metrics = { + 'accuracy': valid_accuracy, + 'f1': valid_f1, + 'overall_accuracy': overall_accuracy, + 'overall_f1': overall_f1, + 'response_rate': response_rate, + 'valid_samples': len(valid_pairs), + 'total_samples': total_samples, + 'invalid_predictions': invalid_predictions + } + + elif is_single_task and is_numerical_task: + # Single numerical task (e.g., age, bmi, glucose) + task_name = targets[0] + pred_values = [] + true_values = [] + + for pred in decoded_preds: + pred_clean = pred.strip() + # Extract first number + match = re.search(r'(\d+\.?\d*)', pred_clean) + if match: + pred_values.append(float(match.group(1))) + else: + pred_values.append(-1) + + for label in decoded_labels: + label_clean = label.strip() + match = re.search(r'(\d+\.?\d*)', label_clean) + if match: + true_values.append(float(match.group(1))) + else: + 
true_values.append(-1) + + # Valid pairs + valid_pairs = [(p, t) for p, t in zip(pred_values, true_values) if p != -1 and t != -1] + + if valid_pairs: + valid_preds, valid_trues = zip(*valid_pairs) + mae = mean_absolute_error(valid_trues, valid_preds) + rmse = np.sqrt(mean_squared_error(valid_trues, valid_preds)) + else: + mae = 0.0 + rmse = 0.0 + + invalid_predictions = pred_values.count(-1) + response_rate = (total_samples - invalid_predictions) / total_samples if total_samples > 0 else 0 + + metrics = { + f'{task_name}_mae': mae, + f'{task_name}_rmse': rmse, + 'response_rate': response_rate, + 'valid_samples': len(valid_pairs), + 'total_samples': total_samples, + 'invalid_predictions': invalid_predictions + } + + elif is_single_task: + # Other categorical tasks (not sex) + # Extract unique labels from ground truth + print(f"[INFO] Generic categorical task detected: {task_name}") + print(f"[INFO] Extracting labels from ground truth...") + + # First pass: collect all possible labels from ground truth + all_labels = set() + for label in decoded_labels: + label_clean = label.lower().strip() + # Try to extract label after common patterns + # Pattern 1: "appears to be X" + match = re.search(r'appears to be\s+(\w+)', label_clean) + if match: + all_labels.add(match.group(1)) + # Pattern 2: "is X" + elif re.search(r'\bis\s+(\w+)', label_clean): + match = re.search(r'\bis\s+(\w+)', label_clean) + all_labels.add(match.group(1)) + # Pattern 3: just the label itself (e.g., "control", "patient") + else: + words = label_clean.split() + if len(words) > 0: + all_labels.add(words[-1]) # Take last word as label + + # Create label to idx mapping + label_to_idx = {label: idx for idx, label in enumerate(sorted(all_labels))} + print(f"[INFO] Detected labels: {label_to_idx}") + + pred_values = [] + true_values = [] + + for pred in decoded_preds: + pred_clean = pred.lower().strip() + found = False + for label_text in label_to_idx.keys(): + if re.search(rf'\b{label_text}\b', pred_clean): + pred_values.append(label_to_idx[label_text]) + found = True + break + if not found: + pred_values.append(-1) + + for label in decoded_labels: + label_clean = label.lower().strip() + found = False + for label_text in label_to_idx.keys(): + if re.search(rf'\b{label_text}\b', label_clean): + true_values.append(label_to_idx[label_text]) + found = True + break + if not found: + true_values.append(-1) + + # Valid pairs + valid_pairs = [(p, t) for p, t in zip(pred_values, true_values) if p != -1 and t != -1] + + if valid_pairs: + valid_preds, valid_trues = zip(*valid_pairs) + accuracy = balanced_accuracy_score(valid_trues, valid_preds) + f1 = f1_score(valid_trues, valid_preds, average='macro') + else: + accuracy = 0.0 + f1 = 0.0 + + invalid_predictions = pred_values.count(-1) + response_rate = (total_samples - invalid_predictions) / total_samples if total_samples > 0 else 0 + + metrics = { + 'accuracy': accuracy, + 'f1': f1, + 'response_rate': response_rate, + 'valid_samples': len(valid_pairs), + 'total_samples': total_samples, + 'invalid_predictions': invalid_predictions + } + + # else: + # # Multi-task + # for task_name in targets: + # if task_name in categorical_tasks: + # # Categorical task + # pattern = rf'{task_name}:\s*(\w+)' + # pred_values = [] + # true_values = [] + + # for pred in decoded_preds: + # pred_clean = pred.lower().strip() + # match = re.search(pattern, pred_clean) + # if match: + # label_text = match.group(1) + # if label_text in categorical_tasks[task_name]: + # 
pred_values.append(categorical_tasks[task_name][label_text]) + # else: + # pred_values.append(-1) + # else: + # pred_values.append(-1) + + # for label in decoded_labels: + # label_clean = label.lower().strip() + # match = re.search(pattern, label_clean) + # if match: + # label_text = match.group(1) + # if label_text in categorical_tasks[task_name]: + # true_values.append(categorical_tasks[task_name][label_text]) + # else: + # true_values.append(-1) + # else: + # true_values.append(-1) + + # valid_pairs = [(p, t) for p, t in zip(pred_values, true_values) if p != -1 and t != -1] + + # if valid_pairs: + # valid_preds, valid_trues = zip(*valid_pairs) + # accuracy = balanced_accuracy_score(valid_trues, valid_preds) + # f1 = f1_score(valid_trues, valid_preds, average='macro') + # else: + # accuracy = 0.0 + # f1 = 0.0 + + # metrics[f'{task_name}_accuracy'] = accuracy + # metrics[f'{task_name}_f1'] = f1 + + # else: + # # Numerical task + # pattern = rf'{task_name}:\s*(\d+\.?\d*)' + # pred_values = [] + # true_values = [] + + # for pred in decoded_preds: + # pred_clean = pred.strip() + # match = re.search(pattern, pred_clean) + # if match: + # pred_values.append(float(match.group(1))) + # else: + # pred_values.append(-1) + + # for label in decoded_labels: + # label_clean = label.strip() + # match = re.search(pattern, label_clean) + # if match: + # true_values.append(float(match.group(1))) + # else: + # true_values.append(-1) + + # valid_pairs = [(p, t) for p, t in zip(pred_values, true_values) if p != -1 and t != -1] + + # if valid_pairs: + # valid_preds, valid_trues = zip(*valid_pairs) + # mae = mean_absolute_error(valid_trues, valid_preds) + # rmse = np.sqrt(mean_squared_error(valid_trues, valid_preds)) + # else: + # mae = 0.0 + # rmse = 0.0 + + # metrics[f'{task_name}_mae'] = mae + # metrics[f'{task_name}_rmse'] = rmse + + # # Overall response rate + # all_pred_values = [] + # for pred in decoded_preds: + # valid = True + # for task_name in targets: + # if task_name in categorical_tasks: + # pattern = rf'{task_name}:\s*(\w+)' + # else: + # pattern = rf'{task_name}:\s*(\d+\.?\d*)' + # if not re.search(pattern, pred.lower()): + # valid = False + # break + # all_pred_values.append(1 if valid else -1) + + # invalid_predictions = all_pred_values.count(-1) + # response_rate = (total_samples - invalid_predictions) / total_samples if total_samples > 0 else 0 + + # metrics['response_rate'] = response_rate + # metrics['total_samples'] = total_samples + # metrics['invalid_predictions'] = invalid_predictions + + return metrics + + return compute_metrics + + +class CustomTrainer(Trainer): + """ + Modified based on https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L294 + """ + + def __init__(self, model_optimization_type='sequential', *args, **kwargs): + # Set static graph for DDP + super().__init__(*args, **kwargs) + self._static_graph_set = False + self.model_optimization_type= model_optimization_type + + + def _ensure_set_static_graph(self, model): + if not self._static_graph_set and self.is_in_train: + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model._set_static_graph() + self._static_graph_set = True + + + def repack_inputs_except_for_pixel_values(self, inputs, modalities): + """ + inputs = + { + 'T1': + { + 'pixel_values': torch.tensor([torch.tensor]), + 'input_ids': torch.tensor([torch.tensor]), + 'attention_mask': torch.tensor([torch.tensor]), + 'labels': torch.tensor([torch.tensor]) + }, + + 'rsfMRI': + { + 
'pixel_values': torch.tensor([torch.tensor]), + 'input_ids': torch.tensor([torch.tensor]), + 'attention_mask': torch.tensor([torch.tensor]), + 'labels': torch.tensor([torch.tensor]) + }, + + } + + outputs = + { + 'pixel_values': { + 'T1': torch.tensor([torch.tensor]), + 'rsfMRI':torch.tensor([torch.tensor]), + } + 'input_ids': torch.tensor([torch.tensor]), + 'attention_mask': torch.tensor([torch.tensor]), + 'labels': torch.tensor([torch.tensor]), + + } + """ + assert len(modalities) > 1 + + outputs = {} + outputs['pixel_values'] = {} + outputs['input_ids'] = [] + outputs['attention_mask'] = [] + outputs['labels'] = [] + + for modality in modalities: + modality_data = inputs[modality] + #print(modality_data) + outputs['pixel_values'][modality] = modality_data['pixel_values'] + outputs['input_ids'].append(modality_data['input_ids']) + outputs['attention_mask'].append(modality_data['attention_mask']) + outputs['labels'].append(modality_data['labels']) + + outputs['input_ids'] = torch.cat(outputs['input_ids'], dim=0) + outputs['attention_mask'] = torch.cat(outputs['attention_mask'], dim=0) + outputs['labels'] = torch.cat(outputs['labels'], dim=0) + + return outputs + + + def _compute_modality_loss(self, model, inputs, labels=None): + """Helper function to compute loss for a single modality""" + outputs = model(**inputs) + + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + + if labels is not None: + unwrapped_model = self.accelerator.unwrap_model(model) + model_name = unwrapped_model.base_model.model._get_name() if _is_peft_model(unwrapped_model) else unwrapped_model._get_name() + loss = self.label_smoother(outputs, labels, shift_labels=model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()) + else: + if isinstance(outputs, dict) and "loss" not in outputs: + raise ValueError(f"Model did not return loss. Got keys: {','.join(outputs.keys())}") + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + return loss, outputs + + + def _compute_dummy_gradient(self, model, active_modality): + """Compute dummy gradient for inactive modality parameters.""" + skip_modality = 'rsfMRI' if active_modality == 'T1' else 'T1' + + # Get embeddings module + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + base_model = model.module + else: + base_model = model + + embeddings = (base_model.vision_tower.vision_model.embeddings + if hasattr(base_model, 'vision_tower') + else base_model.vision_model.embeddings) # vision_tower is for LLaVA + + # Compute dummy loss + dummy_loss = 0. + for name, param in embeddings.named_parameters(): + if skip_modality in name: + dummy_loss += param.sum() * 0. + + return dummy_loss + + + def _compute_loss_with_labels(self, model, inputs): + """Compute loss handling both label_smoother and direct cases.""" + # Extract labels if using label smoother + if self.label_smoother and "labels" in inputs: + labels = inputs.pop("labels") + return self._compute_modality_loss(model, inputs, labels) + + outputs = model(**inputs) + + # Extract loss from various output formats + if hasattr(outputs, 'loss'): + loss = outputs.loss + elif isinstance(outputs, dict) and 'loss' in outputs: + loss = outputs['loss'] + elif isinstance(outputs, (tuple, list)) and len(outputs) > 0: + loss = outputs[0] + else: + raise ValueError(f"Model did not return a loss. 
Output type: {type(outputs)}")
+
+        return loss, outputs
+
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        #TODO
+        # With the current code structure it is inherently impossible to obtain a loss from each of the
+        # two modalities sequentially and then optimize their summed loss.
+        # To get a loss from one modality, the data has to pass through every layer except the patch layer;
+        # once that has happened, passing the second modality's data through those same (not yet updated)
+        # layers makes backward() raise an error.
+        # Interestingly, the X-InstructBLIP paper showed that a multi-modal network can be trained without
+        # jointly optimizing the losses obtained from the different modalities.
+        # OneLLM, on the other hand, does joint optimization, since routing is its main selling point.
+        # For joint optimization we would probably have to re-implement the original BLIP-2 code and load
+        # the pretrained weights on top of it.
+
+        """
+        How the loss is computed by Trainer. By default, all models return the loss in the first element.
+
+        Subclass and override for custom behavior.
+
+        inputs =
+        {
+            'T1':
+                {
+                    'pixel_values': torch.tensor([torch.tensor]),
+                    'input_ids': torch.tensor([torch.tensor]),
+                    'attention_mask': torch.tensor([torch.tensor]),
+                    'labels': torch.tensor([torch.tensor])
+                },
+
+            'rsfMRI':
+                {
+                    'pixel_values': torch.tensor([torch.tensor]),
+                    'input_ids': torch.tensor([torch.tensor]),
+                    'attention_mask': torch.tensor([torch.tensor]),
+                    'labels': torch.tensor([torch.tensor])
+                },
+
+        }
+
+        """
+        self._ensure_set_static_graph(model)
+        total_loss = 0.
+        outputs = None
+        modalities = list(inputs.keys())
+
+        if len(modalities) == 1:
+            # Single modality: add dummy gradient for stability
+            modality = modalities[0]
+            inputs_single = inputs[modality].copy()
+
+            # Dummy loss for unused modality parameters
+            dummy_loss = self._compute_dummy_gradient(model, modality)
+
+            # Compute actual loss
+            loss, outputs = self._compute_loss_with_labels(model, inputs_single)
+            total_loss = dummy_loss + loss
+
+        else:  # len(modalities) >= 2
+            # Multiple modalities: repack and compute
+            inputs_repacked = self.repack_inputs_except_for_pixel_values(inputs, modalities)
+            loss, outputs = self._compute_loss_with_labels(model, inputs_repacked)
+            total_loss = loss
+
+        return (total_loss, outputs) if return_outputs else total_loss
+
+
+    def training_step(self, model, inputs):
+        loss = super().training_step(model, inputs)
+
+        # generation result
+        if self.state.global_step % 50 == 0 and self.state.global_step > 0:
+            self.log_generated_result(model, inputs, mode="training")
+
+        # Log gradients at logging steps
+        modalities = list(inputs.keys())
+        if len(modalities) == 1:
+            if self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0:
+                grad_norms = {}
+                for name, param in model.named_parameters():
+                    if param.requires_grad and param.grad is not None:
+                        if modalities[0] in name:
+                            if 'bias' in name:
+                                continue
+                            else:
+                                grad_norms[f"grad/{name}"] = param.grad.norm().item()
+        else:
+            if self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0:
+                grad_norms = {}
+                for name, param in model.named_parameters():
+                    if param.requires_grad and param.grad is not None:
+                        if 'bias' in name:
+                            continue
+                        else:
+                            grad_norms[f"grad/{name}"] = param.grad.norm().item()
+
+        # Log to loggers through trainer's log() method (only at logging steps, where grad_norms exists)
+        if self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0:
+            self.log(grad_norms)
+
+
+        """
+        # Check gradients after backward
+        for name, param in model.named_parameters():
+            if param.requires_grad and param.grad is not None:
+                print(f"{name} grad norm: {param.grad.norm().item()}")
+        """
+
+        return loss
+
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
+    ) -> 
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform an evaluation step on `model` using `inputs`. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to evaluate. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (`bool`): + Whether or not to return the loss only. + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + + Return: + Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, + logits and labels (each being optional). + """ + + modalities = list(inputs.keys()) + if len(modalities) == 1 and modalities[0] in ['T1', 'rsfMRI']: + inputs = inputs[modalities[0]] + elif len(modalities) > 1: + inputs = self.repack_inputs_except_for_pixel_values(inputs, modalities) + + has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names) + # For CLIP-like models capable of returning loss values. + # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` + # is `True` in `model.forward`. + return_loss = inputs.get("return_loss", None) + if return_loss is None: + return_loss = self.can_return_loss + loss_without_labels = True if len(self.label_names) == 0 and return_loss else False + + inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. 
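+        # Note: self.label_names typically defaults to ['labels'], so the `labels` gathered below
+        # are the token-level targets that compute_metrics later decodes and parses.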
+ if has_labels or loss_without_labels: + labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) + if len(labels) == 1: + labels = labels[0] + else: + labels = None + + with torch.no_grad(): + if is_sagemaker_mp_enabled(): + raw_outputs = smp_forward_only(model, inputs) + if has_labels or loss_without_labels: + if isinstance(raw_outputs, dict): + loss_mb = raw_outputs["loss"] + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"]) + else: + loss_mb = raw_outputs[0] + logits_mb = raw_outputs[1:] + + loss = loss_mb.reduce_mean().detach().cpu() + logits = smp_nested_concat(logits_mb) + else: + loss = None + if isinstance(raw_outputs, dict): + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys) + else: + logits_mb = raw_outputs + logits = smp_nested_concat(logits_mb) + else: + if has_labels or loss_without_labels: + with self.compute_loss_context_manager(): + if len(modalities) == 1 and modalities[0] in ['T1', 'rsfMRI']: # do we need this logic + wrapped_inputs = {modalities[0]: inputs} + loss, outputs = self.compute_loss(model, wrapped_inputs, return_outputs=True) + else: + loss, outputs = self.compute_loss(model, inputs, return_outputs=True) + + if loss is not None: + if isinstance(loss, torch.Tensor): + loss = loss.mean().detach() + else: + loss = torch.tensor(loss) + + if isinstance(outputs, dict): + # LLaVA + logits = outputs.get('logits', None) + if logits is None: + # fallback + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) + if len(logits) == 1: + logits = logits[0] + elif hasattr(outputs, 'logits'): + logits = outputs.logits + else: + logits = outputs[1:] if len(outputs) > 1 else None + else: + loss = None + with self.compute_loss_context_manager(): + outputs = model(**inputs) + if isinstance(outputs, dict): + logits = outputs.get('logits', None) + if logits is None: + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) + elif hasattr(outputs, 'logits'): + logits = outputs.logits + else: + logits = outputs + + # TODO: this needs to be fixed and made cleaner later. 
+ if self.args.past_index >= 0 and hasattr(outputs, '__getitem__'): + self._past = outputs[self.args.past_index - 1] + + if prediction_loss_only: + return (loss, None, None) + + if logits is not None: + logits = nested_detach(logits) + if isinstance(logits, (tuple, list)) and len(logits) == 1: + logits = logits[0] + + # Log generated result during evaluation (first sample of each eval) + if not prediction_loss_only and not hasattr(self, '_eval_generation_logged'): + self._eval_generation_logged = True + self.log_generated_result(model, inputs, mode="evaluation") + + return (loss, logits, labels) + + + def log_generated_result(self, model, inputs, mode="training"): + """ + Log generated result during training or evaluation + + Args: + model: The model to use for generation + inputs: Input dictionary (wrapped or unwrapped) + mode: "training" or "evaluation" + """ + actual_model = model.module if hasattr(model, 'module') else model + + # Only set eval mode for training (already in eval during evaluation) + if mode == "training": + actual_model.eval() + + with torch.no_grad(): + try: + # Handle input format (different for training vs evaluation) + if 'pixel_values' in inputs and 'input_ids' in inputs: + sample_input = inputs + else: + # Still wrapped in modality key (typical for training) + modality_keys = [k for k in inputs.keys() if k in ['T1', 'rsfMRI']] + if modality_keys: + sample_input = inputs[modality_keys[0]] + else: + sample_input = inputs + + # Get first sample from batch + input_ids = sample_input['input_ids'][0] + + # Handle pixel_values (supports both single-image and multi-image) + pixel_values = sample_input['pixel_values'] + if len(pixel_values.shape) == 6: + # Multi-image: [batch, num_images, C, D, H, W] -> take first batch + pixel_values_sample = pixel_values[0:1] # [1, num_images, C, D, H, W] + elif len(pixel_values.shape) == 5: + # Single-image: [batch, C, D, H, W] -> take first batch + pixel_values_sample = pixel_values[0:1] # [1, C, D, H, W] + else: + print(f"[WARN] Unexpected pixel_values shape: {pixel_values.shape}") + return + + # Search for LAST assistant token (multi-turn: we want to generate the final answer) + # Conversation structure: + # Turn 1: user → assistant (acknowledgment) + # Turn 2: user → assistant (answer) ← We generate this! 
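+                # i.e. the prompt kept for generation is everything up to and including the final
+                # "<|im_start|>assistant\n" marker; generate() must then produce the answer turn,
+                # which is what gets logged and compared against the ground truth below.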
+ assistant_variants = ["<|im_start|>assistant\n", "<|im_start|>assistant"] + assistant_positions = [] + + for variant in assistant_variants: + assistant_tokens = self.tokenizer.encode(variant, add_special_tokens=False) + for i in range(len(input_ids) - len(assistant_tokens)): + if torch.equal(input_ids[i:i+len(assistant_tokens)], + torch.tensor(assistant_tokens, device=input_ids.device)): + assistant_positions.append(i + len(assistant_tokens)) + + if len(assistant_positions) == 0: + print(f"[WARN] Assistant token not found in {mode} input") + return + + # Use LAST assistant position (for multi-turn, this is the final answer) + last_assistant_pos = assistant_positions[-1] + prompt_ids = input_ids[:last_assistant_pos].unsqueeze(0) + + # Generate + generated_ids = actual_model.generate( + pixel_values=pixel_values_sample, # Use prepared pixel_values + input_ids=prompt_ids, + max_new_tokens=250, + do_sample=False, + temperature=0.1, + pad_token_id=self.tokenizer.pad_token_id, + ) + + generated_only = generated_ids[0][len(prompt_ids[0]):] + generated_text = self.tokenizer.decode(generated_only, skip_special_tokens=True) + + # Build result dictionary + result = { + "type": mode, + "step": self.state.global_step, + "epoch": float(self.state.epoch) if hasattr(self.state, 'epoch') else 0, + "generated_text": generated_text, + } + + # Add ground truth for evaluation mode + if mode == "evaluation": + labels = sample_input.get('labels', None) + if labels is not None: + labels_clean = labels[0].clone() + labels_clean[labels_clean == -100] = self.tokenizer.pad_token_id + ground_truth = self.tokenizer.decode(labels_clean, skip_special_tokens=True) + else: + ground_truth = "N/A" + result["ground_truth"] = ground_truth + + # Save to JSON + json_file = "generation_logs.json" + if os.path.exists(json_file): + with open(json_file, 'r') as f: + logs = json.load(f) + else: + logs = [] + + logs.append(result) + + with open(json_file, 'w') as f: + json.dump(logs, f, indent=2, ensure_ascii=False) + + # Print output + prefix = "[TRAIN]" if mode == "training" else "[EVAL]" + if mode == "evaluation": + print("\n" + "="*80) + print(f"{prefix} Step: {self.state.global_step}, Epoch: {result['epoch']}") + print(f"{prefix} Generated: {generated_text}") + print(f"{prefix} Ground Truth: {result.get('ground_truth', 'N/A')}") + print("="*80 + "\n") + else: + print(f"{prefix} Step: {self.state.global_step}") + print(f"{prefix} Generated: {generated_text}") + + except Exception as e: + print(f"[ERROR] {mode.capitalize()} generation failed: {e}") + import traceback + traceback.print_exc() + + # Restore train mode only if we changed it + if mode == "training": + actual_model.train() + + def evaluation_loop(self, *args, **kwargs): + """Override to reset generation flag at start of each evaluation""" + # Reset flag so we log generation once per eval + if hasattr(self, '_eval_generation_logged'): + delattr(self, '_eval_generation_logged') + + return super().evaluation_loop(*args, **kwargs) diff --git a/BLIP_MRI/sample_scripts/BLIP_MRI_LLaVaNextInterleave_T1_DDP_interactive.sh b/BLIP_MRI/sample_scripts/BLIP_MRI_LLaVaNextInterleave_T1_DDP_interactive.sh new file mode 100755 index 0000000..0c57451 --- /dev/null +++ b/BLIP_MRI/sample_scripts/BLIP_MRI_LLaVaNextInterleave_T1_DDP_interactive.sh @@ -0,0 +1,21 @@ + +set +x + +cd YOUR_PROJECT_ROOT #TODO: Change to your own scratch space + + +module load python +#module load pytorch/1.13.1 +module load cpe/23.03 + +conda activate BLIP_MRI_llava #TODO: Change to your own conda env + +export 
LIBRARY_PATH=$LD_LIBRARY_PATH +export TORCH_EXTENSIONS_DIR= #TODO: Change to your own scratch space +export HF_HOME= #TODO: Change to your own scratch space +export TORCH_HOME= #TODO: Change to your own scratch space + + +# LLaVA-NeXT-Interleave Multi-turn Comparison Task (Sex or Age) +torchrun --nnodes 1 --nproc_per_node 1 main_BLLaVaNextInterleave_comparison_hf_joint_T1.py + From b8e93d6438a7f9346c8af19f48b4b2c61e000ed2 Mon Sep 17 00:00:00 2001 From: Sue0515 Date: Mon, 15 Dec 2025 14:54:24 -0800 Subject: [PATCH 3/3] fix: Improve training efficiency and metric accuracy - Add --max_subjects option for dataset size limiting (quick testing) - Fix instruction masking to exclude intermediate assistant responses - Fix numerical metric regex to support negative values (BMI_sds) --- .../generate_json_general_comparison_split.py | 25 +++- ...taset_T1_LLaVaNextInterleave_comparison.py | 130 +++++++++++++++--- .../Trainer_LLaVaNextInterleave_comparison.py | 10 +- 3 files changed, 136 insertions(+), 29 deletions(-) diff --git a/BLIP_MRI/project/data/generate_json_general_comparison_split.py b/BLIP_MRI/project/data/generate_json_general_comparison_split.py index 36da9e1..b1331f8 100644 --- a/BLIP_MRI/project/data/generate_json_general_comparison_split.py +++ b/BLIP_MRI/project/data/generate_json_general_comparison_split.py @@ -18,7 +18,7 @@ from pathlib import Path import random -def load_subjects_and_images(meta_path, img_dir, subject_id_col, target_col, study_sample='ABCD'): +def load_subjects_and_images(meta_path, img_dir, subject_id_col, target_col, study_sample='ABCD', max_subjects=None): """Load metadata and available images""" # Load metadata @@ -54,6 +54,20 @@ def load_subjects_and_images(meta_path, img_dir, subject_id_col, target_col, stu # Filter subjects with both metadata and images meta = meta[meta[subject_id_col].isin(image_dict.keys())].reset_index(drop=True) + # Limit number of subjects if specified + if max_subjects is not None and len(meta) > max_subjects: + print(f"Limiting to {max_subjects} subjects (from {len(meta)})") + # Stratified sampling to maintain class balance + if pd.api.types.is_numeric_dtype(meta[target_col]) and meta[target_col].nunique() <= 10: + # Categorical: stratified by class + samples_per_class = max_subjects // meta[target_col].nunique() + meta = meta.groupby(target_col, group_keys=False).apply( + lambda x: x.sample(min(len(x), samples_per_class), random_state=1234) + ).reset_index(drop=True) + else: + # Numerical or large categorical: random sample + meta = meta.sample(n=max_subjects, random_state=1234).reset_index(drop=True) + # Remap sex values to 0/1 if target_col is 'sex' if 'sex' in target_col.lower(): unique_values = meta[target_col].unique() @@ -739,6 +753,7 @@ def main(): help='Validation set ratio') parser.add_argument('--seed', type=int, default=1234, help='Random seed') + parser.add_argument('--max_subjects', type=int, default=None, help='Maximum number of subjects to use (for quick testing)') args = parser.parse_args() @@ -772,7 +787,7 @@ def main(): # Load data print("\n[Step 1] Loading subjects and images...") meta, image_dict = load_subjects_and_images( - args.meta_path, args.img_dir, args.subject_id_col, args.target_col, args.study_sample + args.meta_path, args.img_dir, args.subject_id_col, args.target_col, args.study_sample, args.max_subjects ) print(f"Loaded {len(meta)} subjects with images") @@ -868,9 +883,9 @@ def main(): with open(test_path, 'w') as f: json.dump(test_tasks, f, indent=2) - print(f"\n✓ Saved: {train_path}") - print(f"✓ 
Saved: {val_path}") - print(f"✓ Saved: {test_path}") + print(f"\nSaved: {train_path}") + print(f"Saved: {val_path}") + print(f"Saved: {test_path}") # Summary print("\n" + "=" * 70) diff --git a/BLIP_MRI/project/dataset/dataset_T1_LLaVaNextInterleave_comparison.py b/BLIP_MRI/project/dataset/dataset_T1_LLaVaNextInterleave_comparison.py index 3c50cae..ee33243 100644 --- a/BLIP_MRI/project/dataset/dataset_T1_LLaVaNextInterleave_comparison.py +++ b/BLIP_MRI/project/dataset/dataset_T1_LLaVaNextInterleave_comparison.py @@ -128,6 +128,85 @@ def __build_conversation_text__(self, task): full_text += "<|im_end|>\n" return full_text + + + # def __preprocess_as_hf__(self, images, full_text): + # """ + # Tokenize multi-turn conversation and apply instruction masking + + # Args: + # images: List of [ref_image_tensor, query_image_tensor] + # full_text: Complete conversation text + + # Returns: + # Dictionary with pixel_values, input_ids, attention_mask, labels + # """ + # inputs = {} + # inputs['pixel_values'] = {} + # inputs['input_ids'] = {} + # inputs['attention_mask'] = {} + # inputs['labels'] = {} + + # # ========== 핵심 수정! ========== + # # 두 이미지를 개별적으로 batch 차원 추가한 후 합치기 + # # ref_image: [C, H, W, D] → [1, C, H, W, D] + # # query_image: [C, H, W, D] → [1, C, H, W, D] + # # 합치기: [2, C, H, W, D] → 이제 PatchEmbed가 batch=2로 처리 + + # processed_images = [] + # for img in images: + # # Add batch dimension to each image + # processed_images.append(img.unsqueeze(0)) # [1, C, H, W, D] + + # # Concatenate along batch dimension + # batched_images = torch.cat(processed_images, dim=0) # [2, C, H, W, D] + + # inputs['pixel_values']['T1'] = batched_images + # # ================================== + + # # Tokenize full conversation + # full_encoding = self.tokenizer( + # full_text, + # add_special_tokens=True, + # padding='max_length', + # max_length=512, + # truncation=True, + # return_tensors='pt' + # ) + + # input_ids = full_encoding['input_ids'].squeeze(0) + # attention_mask = full_encoding['attention_mask'].squeeze(0) + + # # Initialize labels + # labels = input_ids.clone() + # labels[attention_mask == 0] = -100 # Mask padding + + # # Apply instruction masking: mask everything except the LAST assistant's response + # # We want to train only on the final answer, not on the intermediate "Understood" response + + # # Find all assistant tokens + # assistant_pattern = "<|im_start|>assistant\n" + # assistant_tokens = self.tokenizer.encode(assistant_pattern, add_special_tokens=False) + # assistant_tensor = torch.tensor(assistant_tokens, device=input_ids.device) + + # assistant_positions = [] + # for i in range(len(input_ids) - len(assistant_tokens) + 1): + # if torch.equal(input_ids[i:i+len(assistant_tokens)], assistant_tensor): + # assistant_positions.append(i + len(assistant_tokens)) + + # if len(assistant_positions) >= 2: + # # Mask everything before the LAST assistant response + # last_assistant_pos = assistant_positions[-1] + # labels[:last_assistant_pos] = -100 + # elif len(assistant_positions) == 1: + # # Only one assistant response (shouldn't happen in 2-turn, but handle it) + # labels[:assistant_positions[0]] = -100 + + # inputs['input_ids']['T1'] = input_ids + # inputs['attention_mask']['T1'] = attention_mask + # inputs['labels']['T1'] = labels + + # return inputs def __preprocess_as_hf__(self, images, full_text): """ @@ -146,22 +225,19 @@ def __preprocess_as_hf__(self, images, full_text): inputs['attention_mask'] = {} inputs['labels'] = {} - # ========== 핵심 수정! 
========== - # 두 이미지를 개별적으로 batch 차원 추가한 후 합치기 - # ref_image: [C, H, W, D] → [1, C, H, W, D] - # query_image: [C, H, W, D] → [1, C, H, W, D] - # 합치기: [2, C, H, W, D] → 이제 PatchEmbed가 batch=2로 처리 + # Process multiple images for interleave-style multi-image handling + # Each image: [C, H, W, D] = [1, 120, 120, 120] + # Stack along batch dimension: [num_images, C, H, W, D] + # Examples: + # - 2 images (1 ref + 1 query): [2, 1, 120, 120, 120] + # - 4 images (3 refs + 1 query): [4, 1, 120, 120, 120] + # + # This allows PatchEmbedInterleave to process each image independently + # The batch dimension here represents multiple images, NOT multiple samples + # Each will be independently processed through vision encoder + stacked_images = torch.stack(images) # [num_images, 1, 120, 120, 120] - processed_images = [] - for img in images: - # Add batch dimension to each image - processed_images.append(img.unsqueeze(0)) # [1, C, H, W, D] - - # Concatenate along batch dimension - batched_images = torch.cat(processed_images, dim=0) # [2, C, H, W, D] - - inputs['pixel_values']['T1'] = batched_images - # ================================== + inputs['pixel_values']['T1'] = stacked_images # Tokenize full conversation full_encoding = self.tokenizer( @@ -183,7 +259,7 @@ def __preprocess_as_hf__(self, images, full_text): # Apply instruction masking: mask everything except the LAST assistant's response # We want to train only on the final answer, not on the intermediate "Understood" response - # Find all assistant tokens + # Find all assistant start tokens assistant_pattern = "<|im_start|>assistant\n" assistant_tokens = self.tokenizer.encode(assistant_pattern, add_special_tokens=False) assistant_tensor = torch.tensor(assistant_tokens, device=input_ids.device) @@ -194,9 +270,26 @@ def __preprocess_as_hf__(self, images, full_text): assistant_positions.append(i + len(assistant_tokens)) if len(assistant_positions) >= 2: - # Mask everything before the LAST assistant response + # Mask everything before the LAST assistant response (including intermediate assistants) + # This includes: all user turns, all previous assistant responses last_assistant_pos = assistant_positions[-1] labels[:last_assistant_pos] = -100 + + # Additionally mask all PREVIOUS assistant responses (between first and last) + # Find <|im_end|> tokens to identify where each assistant response ends + im_end_pattern = "<|im_end|>\n" + im_end_tokens = self.tokenizer.encode(im_end_pattern, add_special_tokens=False) + im_end_tensor = torch.tensor(im_end_tokens, device=input_ids.device) + + # Mask intermediate assistant responses (between positions[0] and positions[-1]) + for assistant_start in assistant_positions[:-1]: # All except last + # Find the next <|im_end|> after this assistant start + for j in range(assistant_start, last_assistant_pos): + if j + len(im_end_tokens) <= len(input_ids): + if torch.equal(input_ids[j:j+len(im_end_tokens)], im_end_tensor): + # Mask from assistant_start to end of <|im_end|> + labels[assistant_start:j+len(im_end_tokens)] = -100 + break elif len(assistant_positions) == 1: # Only one assistant response (shouldn't happen in 2-turn, but handle it) labels[:assistant_positions[0]] = -100 @@ -206,8 +299,7 @@ def __preprocess_as_hf__(self, images, full_text): inputs['labels']['T1'] = labels return inputs - - + def __len__(self) -> int: return len(self.tasks) diff --git a/BLIP_MRI/project/utils/Trainer_LLaVaNextInterleave_comparison.py b/BLIP_MRI/project/utils/Trainer_LLaVaNextInterleave_comparison.py index 5d279d2..844e3f6 100644 --- 
a/BLIP_MRI/project/utils/Trainer_LLaVaNextInterleave_comparison.py
+++ b/BLIP_MRI/project/utils/Trainer_LLaVaNextInterleave_comparison.py
@@ -148,7 +148,7 @@ def compute_metrics(eval_preds):
             for pred in decoded_preds:
                 pred_clean = pred.strip()
                 # Extract first number
-                match = re.search(r'(\d+\.?\d*)', pred_clean)
+                match = re.search(r'(-?\d+\.?\d*)', pred_clean)
                 if match:
                     pred_values.append(float(match.group(1)))
                 else:
@@ -156,7 +156,7 @@ def compute_metrics(eval_preds):
 
             for label in decoded_labels:
                 label_clean = label.strip()
-                match = re.search(r'(\d+\.?\d*)', label_clean)
+                match = re.search(r'(-?\d+\.?\d*)', label_clean)
                 if match:
                     true_values.append(float(match.group(1)))
                 else:
@@ -207,8 +207,8 @@ def compute_metrics(eval_preds):
             # Pattern 3: just the label itself (e.g., "control", "patient")
             else:
                 words = label_clean.split()
-                if len(words) > 0:
-                    all_labels.add(words[-1])  # Take last word as label
+                if len(words) == 1:  # Only single word
+                    all_labels.add(words[0])
 
             # Create label to idx mapping
             label_to_idx = {label: idx for idx, label in enumerate(sorted(all_labels))}
@@ -822,7 +822,7 @@ def log_generated_result(self, model, inputs, mode="training"):
                 generated_ids = actual_model.generate(
                     pixel_values=pixel_values_sample,  # Use prepared pixel_values
                     input_ids=prompt_ids,
-                    max_new_tokens=250,
+                    max_new_tokens=150,
                     do_sample=False,
                     temperature=0.1,
                     pad_token_id=self.tokenizer.pad_token_id,
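
The regex change above (r'(\d+\.?\d*)' -> r'(-?\d+\.?\d*)') is what lets the numerical
metric path parse signed targets such as BMI_sds. A minimal standalone sketch of the
intended extraction, outside the patch; the helper name and the 0.0 fallback are
illustrative, not taken from the repository:

    import re

    # Updated pattern: the optional leading '-' keeps the sign of values like -1.25.
    NUM_PATTERN = re.compile(r'(-?\d+\.?\d*)')

    def extract_first_number(text, default=0.0):
        """Return the first (possibly negative) number found in the text."""
        match = NUM_PATTERN.search(text.strip())
        return float(match.group(1)) if match else default

    # The old pattern r'(\d+\.?\d*)' would have read this as +1.25.
    assert extract_first_number("Predicted BMI_sds: -1.25") == -1.25
    assert extract_first_number("The subject's age is 10.5 years") == 10.5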