9
9
from concurrent .futures import ThreadPoolExecutor
10
10
11
11
stop_token_max = 10
12
+ sampler_order_max = 7
12
13
13
14
class load_model_inputs (ctypes .Structure ):
14
15
_fields_ = [("threads" , ctypes .c_int ),
@@ -47,6 +48,8 @@ class generation_inputs(ctypes.Structure):
47
48
("mirostat" , ctypes .c_int ),
48
49
("mirostat_tau" , ctypes .c_float ),
49
50
("mirostat_eta" , ctypes .c_float ),
51
+ ("sampler_order" , ctypes .c_int * sampler_order_max ),
52
+ ("sampler_len" , ctypes .c_int ),
50
53
("stop_sequence" , ctypes .c_char_p * stop_token_max ),
51
54
("stream_sse" , ctypes .c_bool )]
52
55
@@ -186,7 +189,7 @@ def load_model(model_filename):
186
189
ret = handle .load_model (inputs )
187
190
return ret
188
191
189
- def generate (prompt ,max_length = 20 , max_context_length = 512 ,temperature = 0.8 ,top_k = 120 , top_a = 0.0 , top_p = 0.85 , typical_p = 1.0 , tfs = 1.0 , rep_pen = 1.1 ,rep_pen_range = 128 ,seed = - 1 ,stop_sequence = [],stream_sse = False ):
192
+ def generate (prompt ,max_length = 20 , max_context_length = 512 , temperature = 0.8 , top_k = 120 , top_a = 0.0 , top_p = 0.85 , typical_p = 1.0 , tfs = 1.0 , rep_pen = 1.1 , rep_pen_range = 128 , mirostat = 0 , mirostat_tau = 5.0 , mirostat_eta = 0.1 , sampler_order = None , seed = - 1 , stop_sequence = [], stream_sse = False ):
190
193
inputs = generation_inputs ()
191
194
outputs = ctypes .create_unicode_buffer (ctypes .sizeof (generation_outputs ))
192
195
inputs .prompt = prompt .encode ("UTF-8" )
@@ -205,8 +208,19 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
205
208
inputs .mirostat = int (args .usemirostat [0 ])
206
209
inputs .mirostat_tau = float (args .usemirostat [1 ])
207
210
inputs .mirostat_eta = float (args .usemirostat [2 ])
211
+ elif mirostat in (1 , 2 ):
212
+ inputs .mirostat = mirostat
213
+ inputs .mirostat_tau = mirostat_tau
214
+ inputs .mirostat_eta = mirostat_eta
208
215
else :
209
216
inputs .mirostat = inputs .mirostat_tau = inputs .mirostat_eta = 0
217
+ if sampler_order and 0 < len (sampler_order ) <= sampler_order_max :
218
+ try :
219
+ for i , sampler in enumerate (sampler_order ):
220
+ inputs .sampler_order [i ] = sampler
221
+ inputs .sampler_len = len (sampler_order )
222
+ except TypeError as e :
223
+ print ("ERROR: sampler_order must be a list of integers: " + str (e ))
210
224
inputs .seed = seed
211
225
for n in range (stop_token_max ):
212
226
if not stop_sequence or n >= len (stop_sequence ):
@@ -272,6 +286,10 @@ def run_blocking():
272
286
tfs = genparams .get ('tfs' , 1.0 ),
273
287
rep_pen = genparams .get ('rep_pen' , 1.1 ),
274
288
rep_pen_range = genparams .get ('rep_pen_range' , 128 ),
289
+ mirostat = genparams .get ('mirostat' , 0 ),
290
+ mirostat_tau = genparams .get ('mirostat_tau' , 5.0 ),
291
+ mirostat_eta = genparams .get ('mirostat_eta' , 0.1 ),
292
+ sampler_order = genparams .get ('sampler_order' , None ),
275
293
seed = genparams .get ('sampler_seed' , - 1 ),
276
294
stop_sequence = genparams .get ('stop_sequence' , []),
277
295
stream_sse = stream_flag )
@@ -288,6 +306,10 @@ def run_blocking():
288
306
tfs = genparams .get ('tfs' , 1.0 ),
289
307
rep_pen = genparams .get ('rep_pen' , 1.1 ),
290
308
rep_pen_range = genparams .get ('rep_pen_range' , 128 ),
309
+ mirostat = genparams .get ('mirostat' , 0 ),
310
+ mirostat_tau = genparams .get ('mirostat_tau' , 5.0 ),
311
+ mirostat_eta = genparams .get ('mirostat_eta' , 0.1 ),
312
+ sampler_order = genparams .get ('sampler_order' , None ),
291
313
seed = genparams .get ('sampler_seed' , - 1 ),
292
314
stop_sequence = genparams .get ('stop_sequence' , []),
293
315
stream_sse = stream_flag )
0 commit comments