From 2338c6fbd3de346c2d797f14b9c5caa429c71717 Mon Sep 17 00:00:00 2001
From: Radicat <2285225334@qq.com>
Date: Wed, 2 Oct 2024 19:32:50 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=80[NEW]=20Support=20vLLM.LLM=20;=20sm?=
 =?UTF-8?q?all=20fixes?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 DEVELOPMENT.md                             |   4 +-
 README.md                                  |   4 +-
 asset/DevelopmentStatistics.xlsx           | Bin 10417 -> 10382 bytes
 src/fastmindapi/model/llama_cpp/LLM.py     |   1 +
 src/fastmindapi/model/openai/ChatModel.py  |   1 -
 .../model/transformers/CausalLM.py         |   4 +-
 .../model/transformers/PeftModel.py        |   2 +-
 src/fastmindapi/model/vllm/LLM.py          | 193 ++++++++----------
 8 files changed, 99 insertions(+), 110 deletions(-)

diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md
index d45de04..0d35646 100644
--- a/DEVELOPMENT.md
+++ b/DEVELOPMENT.md
@@ -18,11 +18,11 @@
 https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py#L1836
 
 # OpenAI
-
-
 https://platform.openai.com/docs/api-reference/chat/create
 
+# vLLM
+https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py#L312
 
 # Model
diff --git a/README.md b/README.md
index 35f0e65..5d3af91 100644
--- a/README.md
+++ b/README.md
@@ -75,7 +75,7 @@ curl http://127.0.0.1:8000/model/add_info \
   -H "Authorization: Bearer sk-19992001" \
   -d '{
     "model_name": "gemma2",
-    "model_type": "TransformersCausalLM",
+    "model_type": "Transformers_CausalLM",
     "model_path": ".../PTM/gemma-2-2b"
   }'
 
@@ -129,7 +129,7 @@ client = FM.Client(IP="x.x.x.x", PORT=xxx, API_KEY="sk-19992001")
 model_info_list = [
   {
     "model_name": "gemma2",
-    "model_type": "TransformersCausalLM",
+    "model_type": "Transformers_CausalLM",
     "model_path": ".../PTM/gemma-2-2b"
   },
 ]
diff --git a/asset/DevelopmentStatistics.xlsx b/asset/DevelopmentStatistics.xlsx
index 46050750df7b47a064f741d5b40dc4d4807ee6c6..d7caf5859c407161ca16bed94ed433aa9c9e39cc 100644
GIT binary patch
[base85-encoded binary delta for the .xlsx workbook (10417 -> 10382 bytes) omitted]
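Note on the README changes above: the built-in Transformers backend is now registered under the type string "Transformers_CausalLM". Registering the vLLM backend added by this patch presumably follows the same pattern through the Python client; the sketch below is an assumption-based illustration only, since the exact type string for vLLM is not shown in this diff ("vLLM_LLM" and the second model entry are hypothetical).

import fastmindapi as FM

client = FM.Client(IP="x.x.x.x", PORT=xxx, API_KEY="sk-19992001")

model_info_list = [
    {
        "model_name": "gemma2",
        "model_type": "Transformers_CausalLM",  # renamed in this patch
        "model_path": ".../PTM/gemma-2-2b"
    },
    {
        "model_name": "gemma2-vllm",
        "model_type": "vLLM_LLM",               # assumed type string for the new vLLM backend
        "model_path": ".../PTM/gemma-2-2b"
    },
]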
diff --git a/src/fastmindapi/model/llama_cpp/LLM.py b/src/fastmindapi/model/llama_cpp/LLM.py
index 0035a82..8eb8a8d 100644
--- a/src/fastmindapi/model/llama_cpp/LLM.py
+++ b/src/fastmindapi/model/llama_cpp/LLM.py
@@ -7,6 +7,7 @@
 class LlamacppLLM:
     def __init__(self, model):
         self.model = model
+        self.model_name = None
 
     @classmethod
     def from_path(cls,
diff --git a/src/fastmindapi/model/openai/ChatModel.py b/src/fastmindapi/model/openai/ChatModel.py
index 55a5f5d..20c0888 100644
--- a/src/fastmindapi/model/openai/ChatModel.py
+++ b/src/fastmindapi/model/openai/ChatModel.py
@@ -11,7 +11,6 @@ def __init__(self,
         self.client = client
         self.system_prompt = system_prompt
         self.model_name = model_name
-        pass
 
     @classmethod
     def from_client(cls,
diff --git a/src/fastmindapi/model/transformers/CausalLM.py b/src/fastmindapi/model/transformers/CausalLM.py
index 4809c78..dab3cd2 100644
--- a/src/fastmindapi/model/transformers/CausalLM.py
+++ b/src/fastmindapi/model/transformers/CausalLM.py
@@ -8,8 +8,10 @@ def __init__(self,
                  model):
         self.tokenizer = tokenizer
         self.model = model
+        self.model_name = None
+
         self.model.eval()
-        pass
+
 
     @classmethod
     def from_path(cls,
diff --git a/src/fastmindapi/model/transformers/PeftModel.py b/src/fastmindapi/model/transformers/PeftModel.py
index b2b8e2f..6281145 100644
--- a/src/fastmindapi/model/transformers/PeftModel.py
+++ b/src/fastmindapi/model/transformers/PeftModel.py
@@ -6,7 +6,7 @@ def __init__(self, base_model: TransformersCausalLM,
         self.raw_model = base_model.model
         self.tokenizer = base_model.tokenizer
         self.model = peft_model
-        pass
+        self.model_name = None
 
     @classmethod
     def from_path(cls, base_model: TransformersCausalLM,
diff --git a/src/fastmindapi/model/vllm/LLM.py b/src/fastmindapi/model/vllm/LLM.py
index 7bd486d..b44e731 100644
--- a/src/fastmindapi/model/vllm/LLM.py
+++ b/src/fastmindapi/model/vllm/LLM.py
@@ -4,33 +4,25 @@
 class vLLMLLM:
     def __init__(self,
-                 tokenizer,
                  model):
-        self.tokenizer = tokenizer
         self.model = model
-        # self.model.eval()
-        pass
+        self.tokenizer = self.model.get_tokenizer()
+        self.model_name = None
 
     @classmethod
     def from_path(cls,
                   model_path: str):
         from vllm import LLM
-        return cls(AutoTokenizer.from_pretrained(model_path, trust_remote_code=True),
-                   AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map="auto"))
+        return cls(LLM(model=model_path, trust_remote_code=True))
 
     def __call__(self,
                  input_text: str,
                  max_new_tokens: int = None):
-        import torch
-        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
-        with torch.no_grad():
-            outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
-        full_text = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        # output_text = full_text[len(input_text):]
-        re_inputs = self.tokenizer.batch_decode(inputs.input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        output_text = full_text[len(re_inputs):]
+        from vllm import SamplingParams
+        outputs = self.model.generate([input_text], SamplingParams(**({ "max_tokens": max_new_tokens } if max_new_tokens else {})))
+        output_text = outputs[0].outputs[0].text
         return output_text
-    
+
     def generate(self,
                  input_text: str,
                  max_new_tokens: Optional[int] = None,
@@ -38,70 +30,61 @@ def generate(self,
                  logits_top_k: Optional[int] = None,
                  stop_strings: Optional[list[str]] = None,
                  config: Optional[dict] = None):
-        import torch
-        import torch.nn.functional as F
-
-        inputs = self.tokenizer(input_text, return_tensors='pt').to(self.model.device) # shape: (1, sequence_length)
-        input_id_list = inputs.input_ids[0].tolist()
-        input_token_list = [self.tokenizer.decode([token_id]) for token_id in input_id_list]
-
-        with torch.no_grad():
-            generate_kwargs = {"generation_config": clean_dict_null_value(config) if config else None,
-                               "max_new_tokens": max_new_tokens,
-                               "stop_strings": stop_strings}
-            outputs = self.model.generate(inputs.input_ids,
-                                          **clean_dict_null_value(generate_kwargs),
-                                          tokenizer=self.tokenizer)
-        full_id_list = outputs[0].tolist()
-        full_token_list = [self.tokenizer.decode([token_id]) for token_id in full_id_list]
-        full_text = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        from vllm import SamplingParams
+        sampling_kwargs = {
+            "max_tokens": max_new_tokens,
+            "logprobs": logits_top_k if return_logits else None,
+            "prompt_logprobs": logits_top_k if return_logits else None,
+            "stop": stop_strings,
+            "repetition_penalty": (config["repetition_penalty"] if "repetition_penalty" in config else None) if config else None,
+            "temperature": (config["temperature"] if "temperature" in config else None) if config else None,
+            "top_p": (config["top_p"] if "top_p" in config else None) if config else None,
+            "top_k": (config["top_k"] if "top_k" in config else None) if config else None,
+        }
+        outputs = self.model.generate([input_text], SamplingParams(**clean_dict_null_value(sampling_kwargs)))
 
-        # output_text = full_text[len(input_text):]
-        re_inputs = self.tokenizer.batch_decode(inputs.input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        output_text = full_text[len(re_inputs):]
+        output_text = outputs[0].outputs[0].text
+        full_text = input_text + output_text
+
+        input_id_list = outputs[0].prompt_token_ids
+        output_id_list = list(outputs[0].outputs[0].token_ids)
+        full_id_list = input_id_list + output_id_list
+
+        full_token_list = [self.tokenizer.decode([token_id]) for token_id in full_id_list]
+        input_token_list = full_token_list[:len(input_id_list)]
 
         logits_list = None
         if return_logits:
-            # Get the model's output logits
-            fulls = self.tokenizer(full_text, return_tensors='pt')
-            with torch.no_grad():
-                logits = self.model(**fulls).logits  # shape: (batch_size, sequence_length, vocab_size)
-            probabilities = F.softmax(logits, dim=-1)  # shape: (1, sequence_length, vocab_size)
-
-            # Use torch.topk over the vocab_size dimension to get the top_k logits and token IDs
-            topk_logits, topk_tokens = torch.topk(logits, k=logits_top_k, dim=-1)  # shape: (batch_size, sequence_length, top_k)
-            topk_probs, topk_tokens2 = torch.topk(probabilities, k=logits_top_k, dim=-1)  # shape: (batch_size, sequence_length, top_k)
-            assert(torch.equal(topk_tokens,topk_tokens2))
-
-            # Extract batch_size and sequence_length
-            _, sequence_length, _ = topk_tokens.shape
-            assert sequence_length == len(full_id_list)
-
-            # Iterate over each position and print the top_k tokens and logits
+            import math
+            raw_input_logits_list = outputs[0].prompt_logprobs
+            raw_output_logits_list = outputs[0].outputs[0].logprobs
+            raw_logits_list = raw_input_logits_list + raw_output_logits_list
+
             logits_list = [{"id": full_id_list[0], "token": full_token_list[0]}]
-            for i in range(sequence_length-1):
-                token_id = full_id_list[i+1]
-                token = full_token_list[i+1]
-                # print(f"Position {i} (Token: {repr(token)}):")
+            for i in range(1, len(full_id_list)):
+                token_id = full_id_list[i]
+                token = full_token_list[i]
+                raw_info_dict = raw_logits_list[i]
                 logits = {
                     "id": token_id,
                     "token": token,
-                    "pred_id": [],
-                    "pred_token": [],
-                    "logits": [],
-                    "probs": [],
-                    # "logprobs": []
+                    "pred_id": [None]*logits_top_k,
+                    "pred_token": [None]*logits_top_k,
+                    # "logits": [],
+                    "probs": [None]*logits_top_k,
+                    "logprobs": [None]*logits_top_k
                 }
-                for j in range(logits_top_k):
-                    pred_token_id = topk_tokens[0, i, j].item()
-                    pred_token = self.tokenizer.decode([pred_token_id])
-                    logit = topk_logits[0, i, j].item()
-                    prob = topk_probs[0, i, j].item()
-                    # print(f"  Top {j+1}: Token ID={pred_token_id}, Token={repr(pred_token)}, Logit={logit:.4f}, Prob={prob:.4%}")
-                    logits["pred_id"].append(pred_token_id)
-                    logits["pred_token"].append(pred_token)
-                    logits["logits"].append(round(logit,4))
-                    logits["probs"].append(round(prob,4))
+                for chosen_token_id in raw_info_dict:
+                    raw_info = raw_info_dict[chosen_token_id]
+                    rank = raw_info.rank
+                    if rank <= logits_top_k:
+                        logprob = raw_info.logprob
+                        decoded_token = raw_info.decoded_token
+
+                        logits["pred_id"][rank-1] = chosen_token_id
+                        logits["pred_token"][rank-1] = decoded_token
+                        logits["probs"][rank-1] = round(math.exp(logprob),4)
+                        logits["logprobs"][rank-1] = round(logprob,4)
                 logits_list.append(logits)
 
         generation_output = {"output_text": output_text,
@@ -111,7 +94,8 @@ def generate(self,
                              "full_id_list": full_id_list,
                              "full_token_list": full_token_list,
                              "full_text": full_text,
-                             "logits": logits_list}
+                             "logits": logits_list
+                             }
 
         return generation_output
 
@@ -121,55 +105,58 @@ def chat(self,
                  logprobs: Optional[bool] = None,
                  top_logprobs: Optional[int] = None,
                  stop: Optional[list[str]] = None):
-        import torch
         import time
-
-        # Convert the message list into a single input text
-        input_text = ""
-        for message in messages:
-            role = message.role
-            content = message.content
-            input_text += f"{role}: {content}\n"
-        input_text += "assistant: "
-
-        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
-
-        generate_kwargs = {
-            "max_new_tokens": max_completion_tokens,
-            "stop_strings": stop
+        from vllm import SamplingParams
+        sampling_kwargs = {
+            "max_tokens": max_completion_tokens,
+            "logprobs": top_logprobs if logprobs else None,
+            "stop": stop,
         }
-
-        with torch.no_grad():
-            outputs = self.model.generate(**inputs,
-                                          **clean_dict_null_value(generate_kwargs))
-
-        full_text = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        re_inputs = self.tokenizer.batch_decode(inputs.input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        output_text = full_text[len(re_inputs):]
-        # Post-process output_text
-        if output_text.lower().startswith("assistant:"):
-            output_text = output_text[len("assistant:"):].strip()
+        outputs = self.model.chat(messages, SamplingParams(**clean_dict_null_value(sampling_kwargs)))
+
+        openai_logprobs = None
+        if logprobs:
+            openai_logprobs = []
+            for token_prob in outputs[0].outputs[0].logprobs:
+                probs = {
+                    "token": token_prob[next(iter(token_prob))].decoded_token,
+                    "logprob": token_prob[next(iter(token_prob))].logprob,
+                    "top_logprobs": [None]*top_logprobs
+                }
+                for chosen_token_id in token_prob:
+                    rank = token_prob[chosen_token_id].rank
+                    if rank <= top_logprobs:
+                        top_prob = {
+                            "token": token_prob[chosen_token_id].decoded_token,
+                            "logprob": token_prob[chosen_token_id].logprob
+                        }
+                        probs["top_logprobs"][rank-1] = top_prob
+                openai_logprobs.append(probs)
 
         choices = []
         choices.append({
             "index": 0,
             "message": {
                 "role": "assistant",
-                "content": output_text
+                "content": outputs[0].outputs[0].text
             },
-            "finish_reason": "stop"
+            "logprobs": openai_logprobs,
+            "finish_reason": outputs[0].outputs[0].finish_reason
         })
 
+        prompt_token_length = len(outputs[0].prompt_token_ids)
+        completion_token_length = len(list(outputs[0].outputs[0].token_ids))
+
         response = {
             "id": f"chatcmpl-{int(time.time())}",
             "object": "chat.completion",
             "created": int(time.time()),
-            "model": self.model.config.name_or_path,
+            "model": self.model_name,
             "choices": choices,
             "usage": {
-                "prompt_tokens": inputs.input_ids.shape[1],
-                "completion_tokens": sum(len(self.tokenizer.encode(text)) for text in output_text),
-                "total_tokens": inputs.input_ids.shape[1] + sum(len(self.tokenizer.encode(text)) for text in output_text)
+                "prompt_tokens": prompt_token_length,
+                "completion_tokens": completion_token_length,
+                "total_tokens": prompt_token_length + completion_token_length
             }
         }
         return response
\ No newline at end of file
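Taken together, the new vLLM backend keeps the same public surface as the Transformers backend: __call__ for plain completion, generate for completion with optional per-token log-probabilities, and chat for an OpenAI-style response dict. A minimal usage sketch follows, assuming vllm is installed and that the class is importable as shown (the module path is inferred from src/fastmindapi/model/vllm/LLM.py; the model path is a placeholder taken from the README examples):

from fastmindapi.model.vllm.LLM import vLLMLLM  # import path inferred from the file location above

model = vLLMLLM.from_path(".../PTM/gemma-2-2b")  # wraps vllm.LLM(model=..., trust_remote_code=True)
model.model_name = "gemma2"                      # reported in the "model" field of chat() responses

# Plain completion
print(model("The capital of France is", max_new_tokens=16))

# Completion with top-5 candidates per position (prompt and generated tokens)
out = model.generate("The capital of France is",
                     max_new_tokens=16,
                     return_logits=True,
                     logits_top_k=5,
                     stop_strings=["\n"])
print(out["output_text"])
print(out["logits"][1]["pred_token"])  # top-5 candidate tokens at the second position of the full sequence

# OpenAI-style chat completion; the message list is forwarded to vllm.LLM.chat()
messages = [{"role": "user", "content": "Say hello."}]
resp = model.chat(messages, max_completion_tokens=32, logprobs=True, top_logprobs=3)
print(resp["choices"][0]["message"]["content"])
print(resp["usage"])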