Skip to content

Commit

Permalink
Language model weights can be specified on cli
Browse files Browse the repository at this point in the history
  • Loading branch information
louiskirsch committed Apr 30, 2017
1 parent 21ef25b commit e9b9d80
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
6 changes: 6 additions & 0 deletions speecht-cli
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ class CLI:
'Specify a directory containing `kenlm-model.binary`, '
'`vocabulary` and `trie`. '
'Language model must be binary format with probing hash table.')
parser.add_argument('--lm-weight', dest='lm_weight', type=float, default=0.8,
help='The weight multiplied with the language model score')
parser.add_argument('--word-count-weight', dest='word_count_weight', type=float, default=0.0,
help='The weight added for each word')
parser.add_argument('--valid-word-count-weight', dest='valid_word_count_weight', type=float, default=2.3,
help='The weight added for each in vocabulary word')

def _add_evaluation_parser(self):
evaluation_parser = self.subparsers.add_parser('evaluate', help='Evaluate the development or test set.',
Expand Down
20 changes: 14 additions & 6 deletions speecht/speech_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,22 @@ def add_training_ops(self, learning_rate: bool = 1e-3, learning_rate_decay_facto
self.update = optimizer.apply_gradients(zip(clipped_gradients, trainables),
global_step=self.global_step, name='apply_gradients')

def add_decoding_ops(self, language_model: str = None):
def add_decoding_ops(self, language_model: str = None, lm_weight: float = 0.8, word_count_weight: float = 0.0,
valid_word_count_weight: float = 2.3):
"""
Add the ops for decoding

Args:
language_model: the file path to the language model to use for beam search decoding or None
word_count_weight: The weight added for each added word
valid_word_count_weight: The weight added for each in-vocabulary word
lm_weight: The weight multiplied with the language model score
"""
with tf.name_scope('decoding'):
self.lm_weight = tf.placeholder_with_default(0.8, shape=(), name='language_model_weight')
self.word_count_weight = tf.placeholder_with_default(0.0, shape=(), name='word_count_weight')
self.valid_word_count_weight = tf.placeholder_with_default(2.3, shape=(), name='valid_word_count_weight')
self.lm_weight = tf.placeholder_with_default(lm_weight, shape=(), name='language_model_weight')
self.word_count_weight = tf.placeholder_with_default(word_count_weight, shape=(), name='word_count_weight')
self.valid_word_count_weight = tf.placeholder_with_default(valid_word_count_weight, shape=(),
name='valid_word_count_weight')

if language_model:
self.softmaxed = tf.log(tf.nn.softmax(self.logits, name='softmax') + 1e-8) / math.log(10)
Expand Down Expand Up @@ -304,7 +309,10 @@ def create_default_model(flags, input_size: int, speech_input: BaseInputLoader)
model.add_decoding_ops()
else:
model.add_training_ops()
model.add_decoding_ops(language_model=flags.language_model)
model.add_decoding_ops(language_model=flags.language_model,
lm_weight=flags.lm_weight,
word_count_weight=flags.word_count_weight,
valid_word_count_weight=flags.valid_word_count_weight)

model.finalize(log_dir=flags.log_dir,
run_name=flags.run_name,
Expand Down

0 comments on commit e9b9d80

Please sign in to comment.