Skip to content

Commit

Permalink
Handle nil attributes when loading generation config
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Feb 26, 2024
1 parent e2f5b4c commit 3bca4c5
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions lib/bumblebee/text/generation_config.ex
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,10 @@ defmodule Bumblebee.Text.GenerationConfig do

opts =
convert!(data,
max_new_tokens: {"max_new_tokens", number()},
min_new_tokens: {"min_new_tokens", number()},
max_length: {"max_length", number()},
min_length: {"min_length", number()},
max_new_tokens: {"max_new_tokens", optional(number())},
min_new_tokens: {"min_new_tokens", optional(number())},
max_length: {"max_length", optional(number())},
min_length: {"min_length", optional(number())},
decoder_start_token_id: {"decoder_start_token_id", optional(number())},
bos_token_id: {"bos_token_id", optional(number())},
eos_token_id: {"eos_token_id", optional(number())},
Expand All @@ -306,10 +306,11 @@ defmodule Bumblebee.Text.GenerationConfig do
data
|> convert!(
sample: {"do_sample", boolean()},
top_k: {"top_k", number()},
top_p: {"top_p", number()},
alpha: {"penalty_alpha", number()}
top_k: {"top_k", optional(number())},
top_p: {"top_p", optional(number())},
alpha: {"penalty_alpha", optional(number())}
)
|> Enum.reject(fn {_key, value} -> value == nil end)
|> Map.new()
|> case do
%{sample: true} = opts ->
Expand Down

0 comments on commit 3bca4c5

Please sign in to comment.