reduce cache size automatically
albertz committed Jan 9, 2025
1 parent e579c74 commit a27435c
Showing 2 changed files with 19 additions and 3 deletions.
users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py (8 additions & 3 deletions)

@@ -516,6 +516,11 @@ class FlashlightLMState:
     label_seq: List[int]
     prev_state: LMState
 
+# Use LRU cache for the LM states (on GPU) and log probs.
+# Note that additionally to the cache size limit here,
+# we free more when we run out of CUDA memory.
+start_lru_cache_size = 1024
+
 class FlashlightLM(LM):
     def __init__(self):
         super().__init__()
@@ -525,9 +530,7 @@ def __init__(self):
         self._count_recalc_whole_seq = 0
         self._recent_debug_log_time = -sys.maxsize
 
-    # Use LRU cache. Note that additionally to the max_size here,
-    # we free more when we run out of CUDA memory.
-    @lru_cache(maxsize=1024)
+    @lru_cache(maxsize=start_lru_cache_size)
     def _calc_next_lm_state(self, state: LMState) -> Tuple[Any, torch.Tensor]:
         """
         :return: LM state, log probs [Vocab]
@@ -550,6 +553,7 @@ def _calc_next_lm_state(self, state: LMState) -> Tuple[Any, torch.Tensor]:
                 count_pop += 1
             if count_pop > 0:
                 print(f"Pop {count_pop} from cache, mem usage {dev_s}: {' '.join(_collect_mem_stats())}")
+                self._calc_next_lm_state.cache_set_maxsize(self._calc_next_lm_state.cache_len())
 
         if prev_lm_state is not None or lm_initial_state is None:
             # We have the prev state, or there is no state at all.
@@ -583,6 +587,7 @@ def start(self, start_with_nothing: bool):
         self._recent_debug_log_time = -sys.maxsize
         self.mapping_states.clear()
         self._calc_next_lm_state.cache_clear()
+        self._calc_next_lm_state.cache_set_maxsize(start_lru_cache_size)
         state = LMState()
         self.mapping_states[state] = FlashlightLMState(label_seq=[model.bos_idx], prev_state=state)
         return state
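Note: taken together, the hunks above form a simple pressure-relief scheme. The cache starts at start_lru_cache_size entries; when CUDA memory runs out, the least-recently-used entries are popped, and the cache limit is then pinned to the shrunken length so it cannot grow back until start() resets it. Below is a minimal standalone sketch of that pattern, assuming the extended lru_cache from users/zeyer/utils/lru_cache is importable; calc_lm_log_probs, the retry wrapper, and the fixed pop count of 16 are illustrative placeholders, not the commit's actual code (in the diff above, the pop loop is driven by the actual CUDA memory state and logs stats via _collect_mem_stats).

import torch

from users.zeyer.utils.lru_cache import lru_cache  # the extended variant modified below

start_lru_cache_size = 1024


@lru_cache(maxsize=start_lru_cache_size)
def calc_lm_log_probs(state):
    """Placeholder for the expensive LM forward pass (holds GPU tensors)."""
    ...


def calc_lm_log_probs_safe(state):
    """Retry on CUDA OOM: evict old cache entries, then shrink the cache limit."""
    while True:
        try:
            return calc_lm_log_probs(state)
        except torch.cuda.OutOfMemoryError:
            count_pop = 0
            # Free the least-recently-used entries; they hold GPU memory.
            while calc_lm_log_probs.cache_len() > 1 and count_pop < 16:
                calc_lm_log_probs.cache_pop_oldest()
                count_pop += 1
            if count_pop == 0:
                raise  # nothing left to evict
            # Pin the limit to the shrunken length so the cache cannot regrow.
            calc_lm_log_probs.cache_set_maxsize(calc_lm_log_probs.cache_len())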
users/zeyer/utils/lru_cache.py (11 additions & 0 deletions)

@@ -27,6 +27,7 @@ def lru_cache(maxsize: int = 128, typed: bool = False):
     with f.cache_info().
     Clear the cache and statistics with f.cache_clear().
     Remove the oldest entry from the cache with f.cache_pop_oldest().
+    Set the maximum cache size to a new value with f.cache_set_maxsize(new_maxsize).
     Access the underlying function with f.__wrapped__.
 
     See: https://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)

@@ -178,12 +179,22 @@ def cache_pop_oldest(*, fallback=not_specified):
             full = False
             return oldvalue

+    def cache_set_maxsize(new_maxsize: int):
+        nonlocal maxsize, full
+        assert new_maxsize > 0
+        with lock:
+            maxsize = new_maxsize
+            while cache_len() > maxsize:
+                cache_pop_oldest()
+            full = cache_len() >= maxsize
+
     wrapper.cache_info = cache_info
     wrapper.cache_clear = cache_clear
     wrapper.cache_parameters = cache_parameters
     wrapper.cache_peek = cache_peek
     wrapper.cache_len = cache_len
     wrapper.cache_pop_oldest = cache_pop_oldest
+    wrapper.cache_set_maxsize = cache_set_maxsize

     update_wrapper(wrapper, user_function)

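For illustration, a small usage sketch of the new cache_set_maxsize (the square function, the sizes, and the import path are hypothetical): shrinking the limit below the current fill level evicts the oldest entries immediately, while raising it evicts nothing and only allows the cache to grow again.

from users.zeyer.utils.lru_cache import lru_cache  # assumed import path


@lru_cache(maxsize=8)
def square(x: int) -> int:
    return x * x


for i in range(8):
    square(i)
assert square.cache_len() == 8

# Shrinking below the current length evicts the oldest entries right away:
# only the most recently used keys (5, 6, 7) survive.
square.cache_set_maxsize(3)
assert square.cache_len() == 3

# Raising the limit keeps the current entries; the cache can fill up again.
square.cache_set_maxsize(100)
square(42)
assert square.cache_len() == 4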
