Skip to content

Commit 2560b35

Browse files
authored
Merge pull request #665 from nhaghighat/add_logit_bias_sampler
Add Logit Sampler
2 parents f4b9657 + 1f83d19 commit 2560b35

File tree

3 files changed

+131
-0
lines changed

3 files changed

+131
-0
lines changed

llama-cpp-2/src/sampling.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::fmt::{Debug, Formatter};
77
use crate::context::LlamaContext;
88
use crate::model::LlamaModel;
99
use crate::token::data_array::LlamaTokenDataArray;
10+
use crate::token::logit_bias::LlamaLogitBias;
1011
use crate::token::LlamaToken;
1112

1213
/// A safe wrapper around `llama_sampler`.
@@ -376,6 +377,42 @@ impl LlamaSampler {
376377
let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_greedy() };
377378
Self { sampler }
378379
}
380+
381+
/// Creates a sampler that applies bias values to specific tokens during sampling.
382+
///
383+
/// # Parameters
384+
/// - ``n_vocab``: [`LlamaModel::n_vocab`]
385+
/// - ``biases``: Slice of [`LlamaLogitBias`] values specifying token-bias pairs
386+
///
387+
/// # Example
388+
/// ```rust
389+
/// use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias};
390+
/// use llama_cpp_2::sampling::LlamaSampler;
391+
///
392+
/// let biases = vec![
393+
/// LlamaLogitBias::new(LlamaToken(1), 1.5), // Increase probability of token 1
394+
/// LlamaLogitBias::new(LlamaToken(2), -1.0), // Decrease probability of token 2
395+
/// ];
396+
///
397+
/// // Assuming vocab_size of 32000
398+
/// let sampler = LlamaSampler::logit_bias(32000, &biases);
399+
/// ```
400+
#[must_use]
401+
pub fn logit_bias(n_vocab: i32, biases: &[LlamaLogitBias]) -> Self {
402+
403+
let data = biases.as_ptr().cast::<llama_cpp_sys_2::llama_logit_bias>();
404+
405+
let sampler = unsafe {
406+
llama_cpp_sys_2::llama_sampler_init_logit_bias(
407+
n_vocab,
408+
biases.len() as i32,
409+
data,
410+
)
411+
};
412+
413+
Self { sampler }
414+
}
415+
379416
}
380417

381418
impl Drop for LlamaSampler {

llama-cpp-2/src/token.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::fmt::Display;
55

66
pub mod data;
77
pub mod data_array;
8+
pub mod logit_bias;
89

910
/// A safe wrapper for `llama_token`.
1011
#[repr(transparent)]
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
//! Safe wrapper around `llama_logit_bias`.
2+
use crate::token::LlamaToken;
3+
4+
/// A transparent wrapper around `llama_logit_bias`.
///
/// Represents a bias to be applied to a specific token during text generation.
/// The bias modifies the likelihood of the token being selected.
///
/// Do not rely on `repr(transparent)` for this type. It should be considered an implementation
/// detail and may change across minor versions.
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaLogitBias {
    // Raw FFI struct (token id + f32 bias); kept private so the `sys` type
    // never leaks into the public API.
    logit_bias: llama_cpp_sys_2::llama_logit_bias,
}
17+
18+
impl LlamaLogitBias {
19+
/// Creates a new logit bias for a specific token with the given bias value.
20+
///
21+
/// # Examples
22+
/// ```
23+
/// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias};
24+
/// let token = LlamaToken::new(1);
25+
/// let bias = LlamaLogitBias::new(token, 1.5);
26+
/// ```
27+
#[must_use]
28+
pub fn new(LlamaToken(token): LlamaToken, bias: f32) -> Self {
29+
Self {
30+
logit_bias: llama_cpp_sys_2::llama_logit_bias {
31+
token,
32+
bias,
33+
},
34+
}
35+
}
36+
37+
/// Gets the token this bias applies to.
38+
///
39+
/// # Examples
40+
/// ```
41+
/// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias};
42+
/// let token = LlamaToken::new(1);
43+
/// let bias = LlamaLogitBias::new(token, 1.5);
44+
/// assert_eq!(bias.token(), token);
45+
/// ```
46+
#[must_use]
47+
pub fn token(&self) -> LlamaToken {
48+
LlamaToken(self.logit_bias.token)
49+
}
50+
51+
/// Gets the bias value.
52+
///
53+
/// # Examples
54+
/// ```
55+
/// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias};
56+
/// let token = LlamaToken::new(1);
57+
/// let bias = LlamaLogitBias::new(token, 1.5);
58+
/// assert_eq!(bias.bias(), 1.5);
59+
/// ```
60+
#[must_use]
61+
pub fn bias(&self) -> f32 {
62+
self.logit_bias.bias
63+
}
64+
65+
/// Sets the token this bias applies to.
66+
///
67+
/// # Examples
68+
/// ```
69+
/// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias};
70+
/// let token = LlamaToken::new(1);
71+
/// let mut bias = LlamaLogitBias::new(token, 1.5);
72+
/// let new_token = LlamaToken::new(2);
73+
/// bias.set_token(new_token);
74+
/// assert_eq!(bias.token(), new_token);
75+
/// ```
76+
pub fn set_token(&mut self, token: LlamaToken) {
77+
self.logit_bias.token = token.0;
78+
}
79+
80+
/// Sets the bias value.
81+
///
82+
/// # Examples
83+
/// ```
84+
/// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias};
85+
/// let token = LlamaToken::new(1);
86+
/// let mut bias = LlamaLogitBias::new(token, 1.5);
87+
/// bias.set_bias(2.0);
88+
/// assert_eq!(bias.bias(), 2.0);
89+
/// ```
90+
pub fn set_bias(&mut self, bias: f32) {
91+
self.logit_bias.bias = bias;
92+
}
93+
}

0 commit comments

Comments
 (0)