Skip to content

Commit f29e4d9

Browse files
committed
Merge branch 'improve-sampler-api-access' into ycros-ropey-pt2
2 parents eed4b71 + 309534d commit f29e4d9

File tree

3 files changed: +84 −8 lines changed

expose.h

+14
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
#pragma once
22

33
const int stop_token_max = 10;

// Samplers mirroring kobold's sampler list and application order.
// Explicit values document the wire format shared with the Python side
// (koboldcpp.py passes these as plain c_int entries in sampler_order).
enum samplers
{
    KCPP_SAMPLER_TOP_K   = 0,
    KCPP_SAMPLER_TOP_A   = 1,
    KCPP_SAMPLER_TOP_P   = 2,
    KCPP_SAMPLER_TFS     = 3,
    KCPP_SAMPLER_TYP     = 4,
    KCPP_SAMPLER_TEMP    = 5,
    KCPP_SAMPLER_REP_PEN = 6,
    KCPP_SAMPLER_MAX     = 7  // count of samplers; sizes sampler_order[]
};
416
struct load_model_inputs
517
{
618
const int threads;
@@ -40,6 +52,8 @@ struct generation_inputs
4052
const int mirostat = 0;
4153
const float mirostat_eta;
4254
const float mirostat_tau;
55+
const samplers sampler_order[KCPP_SAMPLER_MAX];
56+
const int sampler_len;
4357
const char * stop_sequence[stop_token_max];
4458
const bool stream_sse;
4559
};

gpttype_adapter.cpp

+47-7
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,16 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep)
219219
candidates->size = last_idx;
220220
}
221221

222+
void apply_penalties(int n_ctx, int rep_pen_range, float rep_pen, llama_token_data_array & candidates_p)
223+
{
224+
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx);
225+
llama_sample_repetition_penalty(nullptr, &candidates_p,
226+
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
227+
last_n_repeat, rep_pen);
228+
}
229+
222230
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_a, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
223-
int mirostat, float mirostat_tau, float mirostat_eta)
231+
int mirostat, float mirostat_tau, float mirostat_eta, uint sampler_len, const samplers sampler_order[KCPP_SAMPLER_MAX])
224232
{
225233
int id = 0;
226234
std::vector<llama_token_data> candidates;
@@ -231,11 +239,11 @@ int mirostat, float mirostat_tau, float mirostat_eta)
231239

232240
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
233241

234-
// Apply penalties
235-
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx);
236-
llama_sample_repetition_penalty(nullptr, &candidates_p,
237-
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
238-
last_n_repeat, rep_pen);
242+
// Run this except for when we are going to do the sampler reordering case below
243+
if (temp <= 0 || mirostat > 0 || sampler_len == 0)
244+
{
245+
apply_penalties(n_ctx, rep_pen_range, rep_pen, candidates_p);
246+
}
239247

240248
// llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p,
241249
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
@@ -261,6 +269,37 @@ int mirostat, float mirostat_tau, float mirostat_eta)
261269
llama_sample_temperature(nullptr, &candidates_p, temp);
262270
id = sample_token_mirostat_v2(&candidates_p, rng, mirostat_tau, mirostat_eta, &mirostat_mu);
263271
}
272+
else if (sampler_len > 0)
273+
{
274+
for (int i = 0; i < sampler_len; i++) {
275+
switch (sampler_order[i]) {
276+
case KCPP_SAMPLER_TOP_K:
277+
llama_sample_top_k(nullptr, &candidates_p, top_k,1);
278+
break;
279+
case KCPP_SAMPLER_TOP_A:
280+
sample_top_a(&candidates_p,top_a,1);
281+
break;
282+
case KCPP_SAMPLER_TOP_P:
283+
llama_sample_top_p(nullptr, &candidates_p, top_p,1);
284+
break;
285+
case KCPP_SAMPLER_TFS:
286+
llama_sample_tail_free(nullptr, &candidates_p, tfs,1);
287+
break;
288+
case KCPP_SAMPLER_TYP:
289+
llama_sample_typical(nullptr, &candidates_p, typical_p,1);
290+
break;
291+
case KCPP_SAMPLER_TEMP:
292+
llama_sample_temperature(nullptr, &candidates_p, temp);
293+
break;
294+
case KCPP_SAMPLER_REP_PEN:
295+
apply_penalties(n_ctx, rep_pen_range, rep_pen, candidates_p);
296+
break;
297+
default:
298+
break;
299+
}
300+
}
301+
id = sample_token(&candidates_p, rng);
302+
}
264303
else
265304
{
266305
// Temperature sampling
@@ -1235,7 +1274,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
12351274

12361275
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
12371276
top_k, top_a, top_p, typical_p, tfs_z, temp, rng,
1238-
params.mirostat,params.mirostat_tau,params.mirostat_eta);
1277+
params.mirostat, params.mirostat_tau, params.mirostat_eta,
1278+
inputs.sampler_len, inputs.sampler_order);
12391279

12401280
last_n_tokens.erase(last_n_tokens.begin());
12411281
last_n_tokens.push_back(id);

koboldcpp.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from concurrent.futures import ThreadPoolExecutor
1010

1111
stop_token_max = 10
12+
sampler_order_max = 7
1213

1314
class load_model_inputs(ctypes.Structure):
1415
_fields_ = [("threads", ctypes.c_int),
@@ -47,6 +48,8 @@ class generation_inputs(ctypes.Structure):
4748
("mirostat", ctypes.c_int),
4849
("mirostat_tau", ctypes.c_float),
4950
("mirostat_eta", ctypes.c_float),
51+
("sampler_order", ctypes.c_int * sampler_order_max),
52+
("sampler_len", ctypes.c_int),
5053
("stop_sequence", ctypes.c_char_p * stop_token_max),
5154
("stream_sse", ctypes.c_bool)]
5255

@@ -186,7 +189,7 @@ def load_model(model_filename):
186189
ret = handle.load_model(inputs)
187190
return ret
188191

189-
def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=120, top_a=0.0 ,top_p=0.85, typical_p=1.0, tfs=1.0 ,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[],stream_sse=False):
192+
def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=None, seed=-1, stop_sequence=[], stream_sse=False):
190193
inputs = generation_inputs()
191194
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
192195
inputs.prompt = prompt.encode("UTF-8")
@@ -205,8 +208,19 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
205208
inputs.mirostat = int(args.usemirostat[0])
206209
inputs.mirostat_tau = float(args.usemirostat[1])
207210
inputs.mirostat_eta = float(args.usemirostat[2])
211+
elif mirostat in (1, 2):
212+
inputs.mirostat = mirostat
213+
inputs.mirostat_tau = mirostat_tau
214+
inputs.mirostat_eta = mirostat_eta
208215
else:
209216
inputs.mirostat = inputs.mirostat_tau = inputs.mirostat_eta = 0
217+
if sampler_order and 0 < len(sampler_order) <= sampler_order_max:
218+
try:
219+
for i, sampler in enumerate(sampler_order):
220+
inputs.sampler_order[i] = sampler
221+
inputs.sampler_len = len(sampler_order)
222+
except TypeError as e:
223+
print("ERROR: sampler_order must be a list of integers: " + str(e))
210224
inputs.seed = seed
211225
for n in range(stop_token_max):
212226
if not stop_sequence or n >= len(stop_sequence):
@@ -272,6 +286,10 @@ def run_blocking():
272286
tfs=genparams.get('tfs', 1.0),
273287
rep_pen=genparams.get('rep_pen', 1.1),
274288
rep_pen_range=genparams.get('rep_pen_range', 128),
289+
mirostat=genparams.get('mirostat', 0),
290+
mirostat_tau=genparams.get('mirostat_tau', 5.0),
291+
mirostat_eta=genparams.get('mirostat_eta', 0.1),
292+
sampler_order=genparams.get('sampler_order', None),
275293
seed=genparams.get('sampler_seed', -1),
276294
stop_sequence=genparams.get('stop_sequence', []),
277295
stream_sse=stream_flag)
@@ -288,6 +306,10 @@ def run_blocking():
288306
tfs=genparams.get('tfs', 1.0),
289307
rep_pen=genparams.get('rep_pen', 1.1),
290308
rep_pen_range=genparams.get('rep_pen_range', 128),
309+
mirostat=genparams.get('mirostat', 0),
310+
mirostat_tau=genparams.get('mirostat_tau', 5.0),
311+
mirostat_eta=genparams.get('mirostat_eta', 0.1),
312+
sampler_order=genparams.get('sampler_order', None),
291313
seed=genparams.get('sampler_seed', -1),
292314
stop_sequence=genparams.get('stop_sequence', []),
293315
stream_sse=stream_flag)

0 commit comments

Comments (0)