From a27435c5572e6585bd340b7b43f564becb726507 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Thu, 9 Jan 2025 22:12:23 +0100
Subject: [PATCH] reduce cache size automatically

---
 .../exp2024_04_23_baselines/ctc_recog_ext.py | 11 ++++++++---
 users/zeyer/utils/lru_cache.py               | 11 +++++++++++
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py
index 18309ddbb..c7c89d334 100644
--- a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py
+++ b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py
@@ -516,6 +516,11 @@ class FlashlightLMState:
         label_seq: List[int]
         prev_state: LMState
 
+    # Use LRU cache for the LM states (on GPU) and log probs.
+    # Note that additionally to the cache size limit here,
+    # we free more when we run out of CUDA memory.
+    start_lru_cache_size = 1024
+
     class FlashlightLM(LM):
         def __init__(self):
             super().__init__()
@@ -525,9 +530,7 @@ def __init__(self):
             self._count_recalc_whole_seq = 0
             self._recent_debug_log_time = -sys.maxsize
 
-        # Use LRU cache. Note that additionally to the max_size here,
-        # we free more when we run out of CUDA memory.
-        @lru_cache(maxsize=1024)
+        @lru_cache(maxsize=start_lru_cache_size)
         def _calc_next_lm_state(self, state: LMState) -> Tuple[Any, torch.Tensor]:
             """
             :return: LM state, log probs [Vocab]
@@ -550,6 +553,7 @@ def _calc_next_lm_state(self, state: LMState) -> Tuple[Any, torch.Tensor]:
                     count_pop += 1
                 if count_pop > 0:
                     print(f"Pop {count_pop} from cache, mem usage {dev_s}: {' '.join(_collect_mem_stats())}")
+                    self._calc_next_lm_state.cache_set_maxsize(self._calc_next_lm_state.cache_len())
 
             if prev_lm_state is not None or lm_initial_state is None:
                 # We have the prev state, or there is no state at all.
@@ -583,6 +587,7 @@ def start(self, start_with_nothing: bool):
             self._recent_debug_log_time = -sys.maxsize
             self.mapping_states.clear()
             self._calc_next_lm_state.cache_clear()
+            self._calc_next_lm_state.cache_set_maxsize(start_lru_cache_size)
             state = LMState()
             self.mapping_states[state] = FlashlightLMState(label_seq=[model.bos_idx], prev_state=state)
             return state
diff --git a/users/zeyer/utils/lru_cache.py b/users/zeyer/utils/lru_cache.py
index 18a1e8abe..6d669df71 100644
--- a/users/zeyer/utils/lru_cache.py
+++ b/users/zeyer/utils/lru_cache.py
@@ -27,6 +27,7 @@ def lru_cache(maxsize: int = 128, typed: bool = False):
     with f.cache_info(). Clear the cache and statistics with f.cache_clear().
     Remove the oldest entry from the cache with f.cache_pop_oldest().
+    Set the maximum cache size to a new value with f.cache_set_maxsize(new_maxsize).
     Access the underlying function with f.__wrapped__.
 
     See: https://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)
@@ -178,12 +179,22 @@ def cache_pop_oldest(*, fallback=not_specified):
             full = False
             return oldvalue
 
+    def cache_set_maxsize(new_maxsize: int):
+        nonlocal maxsize, full
+        assert new_maxsize > 0
+        with lock:
+            maxsize = new_maxsize
+            while cache_len() > maxsize:
+                cache_pop_oldest()
+            full = cache_len() >= maxsize
+
     wrapper.cache_info = cache_info
     wrapper.cache_clear = cache_clear
     wrapper.cache_parameters = cache_parameters
     wrapper.cache_peek = cache_peek
    wrapper.cache_len = cache_len
     wrapper.cache_pop_oldest = cache_pop_oldest
+    wrapper.cache_set_maxsize = cache_set_maxsize
 
     update_wrapper(wrapper, user_function)
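
Usage note (not part of the patch): the sketch below shows how the new f.cache_set_maxsize() hook is meant to be combined with cache_pop_oldest() and cache_len(), mirroring what the patch does in _calc_next_lm_state and start(). It assumes the users/zeyer/utils/lru_cache decorator patched above is importable as written; the function expensive_step, the 90% memory threshold, and the max(1, ...) guard are illustrative choices, not taken from the patch.

import torch

# Assumption: the decorator patched above is importable like this in the repo layout.
from users.zeyer.utils.lru_cache import lru_cache

START_LRU_CACHE_SIZE = 1024  # same starting size as the patch uses


@lru_cache(maxsize=START_LRU_CACHE_SIZE)
def expensive_step(key: int) -> torch.Tensor:
    """Hypothetical costly computation whose result we want to cache."""
    return torch.randn(1024, 1024)  # stand-in; in the patch this is the LM forward pass on GPU


def maybe_shrink_cache(dev: torch.device):
    """Under CUDA memory pressure, evict old entries and lock in the smaller cache size."""
    if dev.type != "cuda":
        return
    count_pop = 0
    while expensive_step.cache_len() > 0:
        free, total = torch.cuda.mem_get_info(dev)
        if (total - free) / total < 0.9:  # illustrative threshold, not from the patch
            break
        expensive_step.cache_pop_oldest()
        count_pop += 1
    if count_pop > 0:
        # Keep the reduced size as the new maximum, as the patch does after popping.
        # cache_set_maxsize() asserts a positive size, hence the max(1, ...) guard.
        expensive_step.cache_set_maxsize(max(1, expensive_step.cache_len()))


def reset_cache():
    """Start over: clear the cache and restore the initial maximum size."""
    expensive_step.cache_clear()
    expensive_step.cache_set_maxsize(START_LRU_CACHE_SIZE)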