diff --git a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py index 0ec72e8ff..948d125d8 100644 --- a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py +++ b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py @@ -690,12 +690,6 @@ def score(self, state: LMState, token_index: int): Returns: (LMState, float): pair of (new state, score for the current word) """ - if len(self.mapping_states) % 1_000_000 == 0: - snapshot = tracemalloc.take_snapshot() - top_stats = snapshot.compare_to(snapshot_start, "lineno") - print(f"[ {len(self.mapping_states)} states, top 100 mallocs ]") - for stat in top_stats[:100]: - print(stat) state_ = self.mapping_states[state] if time.monotonic() - self._recent_debug_log_time > 1: print( @@ -712,6 +706,14 @@ def score(self, state: LMState, token_index: int): self.mapping_states[outstate] = FlashlightLMState( label_seq=state_.label_seq + [token_index], prev_state=state ) + + if len(self.mapping_states) % 1_000_000 == 0: + snapshot = tracemalloc.take_snapshot() + top_stats = snapshot.compare_to(snapshot_start, "lineno") + print(f"[ {len(self.mapping_states)} states, top 100 mallocs ]") + for stat in top_stats[:100]: + print(stat) + _, log_probs_raw = self._calc_next_lm_state(state) return outstate, log_probs_raw[token_index]