diff --git a/speecht-cli b/speecht-cli index ec6417b..a2661f3 100755 --- a/speecht-cli +++ b/speecht-cli @@ -77,6 +77,12 @@ class CLI: 'Specify a directory containing `kenlm-model.binary`, ' '`vocabulary` and `trie`. ' 'Language model must be binary format with probing hash table.') + parser.add_argument('--lm-weight', dest='lm_weight', type=float, default=0.8, + help='The weight multiplied with the language model score') + parser.add_argument('--word-count-weight', dest='word_count_weight', type=float, default=0.0, + help='The weight added for each word') + parser.add_argument('--valid-word-count-weight', dest='valid_word_count_weight', type=float, default=2.3, + help='The weight added for each in vocabulary word') def _add_evaluation_parser(self): evaluation_parser = self.subparsers.add_parser('evaluate', help='Evaluate the development or test set.', diff --git a/speecht/speech_model.py b/speecht/speech_model.py index 265a9b0..f4691e6 100644 --- a/speecht/speech_model.py +++ b/speecht/speech_model.py @@ -81,17 +81,22 @@ def add_training_ops(self, learning_rate: bool = 1e-3, learning_rate_decay_facto self.update = optimizer.apply_gradients(zip(clipped_gradients, trainables), global_step=self.global_step, name='apply_gradients') - def add_decoding_ops(self, language_model: str = None): + def add_decoding_ops(self, language_model: str = None, lm_weight: float = 0.8, word_count_weight: float = 0.0, + valid_word_count_weight: float = 2.3): """ Add the ops for decoding - +j Args: language_model: the file path to the language model to use for beam search decoding or None + word_count_weight: The weight added for each added word + valid_word_count_weight: The weight added for each in vocabulary word + lm_weight: The weight multiplied with the language model scoring """ with tf.name_scope('decoding'): - self.lm_weight = tf.placeholder_with_default(0.8, shape=(), name='language_model_weight') - self.word_count_weight = tf.placeholder_with_default(0.0, shape=(), name='word_count_weight') - self.valid_word_count_weight = tf.placeholder_with_default(2.3, shape=(), name='valid_word_count_weight') + self.lm_weight = tf.placeholder_with_default(lm_weight, shape=(), name='language_model_weight') + self.word_count_weight = tf.placeholder_with_default(word_count_weight, shape=(), name='word_count_weight') + self.valid_word_count_weight = tf.placeholder_with_default(valid_word_count_weight, shape=(), + name='valid_word_count_weight') if language_model: self.softmaxed = tf.log(tf.nn.softmax(self.logits, name='softmax') + 1e-8) / math.log(10) @@ -304,7 +309,10 @@ def create_default_model(flags, input_size: int, speech_input: BaseInputLoader) model.add_decoding_ops() else: model.add_training_ops() - model.add_decoding_ops(language_model=flags.language_model) + model.add_decoding_ops(language_model=flags.language_model, + lm_weight=flags.lm_weight, + word_count_weight=flags.word_count_weight, + valid_word_count_weight=flags.valid_word_count_weight) model.finalize(log_dir=flags.log_dir, run_name=flags.run_name,