Skip to content

Commit 74cc6f6

Browse files
committed
Python - Simplify padding interface
1 parent d1e59e0 commit 74cc6f6

File tree

1 file changed

+36
-20
lines changed

1 file changed

+36
-20
lines changed

bindings/python/src/tokenizer.rs

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -139,29 +139,45 @@ impl Tokenizer {
139139
self.tokenizer.with_truncation(None);
140140
}
141141

142-
fn with_padding(
143-
&mut self,
144-
size: Option<usize>,
145-
direction: &str,
146-
pad_id: u32,
147-
pad_type_id: u32,
148-
pad_token: &str,
149-
) -> PyResult<()> {
150-
let strategy = if let Some(size) = size {
151-
PaddingStrategy::Fixed(size)
142+
#[args(kwargs = "**")]
143+
fn with_padding(&mut self, kwargs: Option<&PyDict>) -> PyResult<()> {
144+
let mut direction = PaddingDirection::Right;
145+
let mut pad_id: u32 = 0;
146+
let mut pad_type_id: u32 = 0;
147+
let mut pad_token = String::from("[PAD]");
148+
let mut max_length: Option<usize> = None;
149+
150+
if let Some(kwargs) = kwargs {
151+
for (key, value) in kwargs {
152+
let key: &str = key.extract()?;
153+
match key {
154+
"direction" => {
155+
let value: &str = value.extract()?;
156+
direction = match value {
157+
"left" => Ok(PaddingDirection::Left),
158+
"right" => Ok(PaddingDirection::Right),
159+
other => Err(PyError(format!(
160+
"Unknown `direction`: `{}`. Use \
161+
one of `left` or `right`",
162+
other
163+
))
164+
.into_pyerr()),
165+
}?;
166+
}
167+
"pad_id" => pad_id = value.extract()?,
168+
"pad_type_id" => pad_type_id = value.extract()?,
169+
"pad_token" => pad_token = value.extract()?,
170+
"max_length" => max_length = value.extract()?,
171+
_ => println!("Ignored unknown kwarg option {}", key),
172+
}
173+
}
174+
}
175+
176+
let strategy = if let Some(max_length) = max_length {
177+
PaddingStrategy::Fixed(max_length)
152178
} else {
153179
PaddingStrategy::BatchLongest
154180
};
155-
let direction = match direction {
156-
"left" => Ok(PaddingDirection::Left),
157-
"right" => Ok(PaddingDirection::Right),
158-
other => Err(PyError(format!(
159-
"Unknown `direction`: `{}`. Use \
160-
one of `left` or `right`",
161-
other
162-
))
163-
.into_pyerr()),
164-
}?;
165181

166182
self.tokenizer.with_padding(Some(PaddingParams {
167183
strategy,

0 commit comments

Comments
 (0)