diff --git a/src/lru_cache.py b/src/lru_cache.py index 9c6f449..3d7eb3c 100644 --- a/src/lru_cache.py +++ b/src/lru_cache.py @@ -1,68 +1,83 @@ class Node: - def __init__(self, val: int, key: int = -1) -> None: + def __init__(self, val: int, key: int) -> None: self.val = val self.key = key - self.prev: Node | None = None - self.next: Node | None = None + self.nxt: Node | None = None + self.prv: Node | None = None - def __repr__(self) -> str: + def __str__(self) -> str: return f"Node(val={self.val}, key={self.key})" class LRUCache: def __init__(self, capacity: int) -> None: - self.capacity = capacity - self.cache: dict[int, Node] = {} - - self.head = Node(-1) - self.tail = Node(-1) - - self.head.next = self.tail - self.tail.prev = self.head - - self.size = 0 - - def insert_as_head(self, node: Node) -> None: - node.next = self.head.next - self.head.next = node - node.prev = self.head + self.n = capacity - assert node.next - node.next.prev = node + self.head = Node(-1, -1) + self.tail = Node(-1, -1) - @staticmethod - def remove(node: Node) -> None: - assert node.next - assert node.prev + self.head.nxt = self.tail + self.tail.prv = self.head - node.prev.next = node.next - node.next.prev = node.prev + self.cache: dict[int, Node] = {} def get(self, key: int) -> int: if key in self.cache: node = self.cache[key] - self.remove(node) - self.insert_as_head(node) + self.touch(node) return node.val else: return -1 + def touch(self, node: Node | None) -> None: + self.cut(node) + self.append(node) + + def append(self, node: Node | None) -> None: + assert node + + prv = self.tail.prv + assert prv + prv.nxt = node + node.nxt = self.tail + self.tail.prv = node + node.prv = prv + + @staticmethod + def cut(node: Node | None) -> None: + assert node + + prv = node.prv + nxt = node.nxt + assert prv and nxt + prv.nxt = nxt + nxt.prv = prv + + def pop_left(self) -> None: + if self.head.nxt is not self.tail: + node = self.head.nxt + self.cut(node) + assert node + del self.cache[node.key] + def put(self, key: int, value: int) -> None: + # existing key, update if key in self.cache: node = self.cache[key] node.val = value - self.remove(node) - else: - node = Node(val=value, key=key) - self.cache[key] = node - self.size += 1 + self.touch(node) - if self.size > self.capacity: - oldest_node = self.tail.prev - assert oldest_node - self.remove(oldest_node) - del self.cache[oldest_node.key] - self.size -= 1 + # new key, free space available + elif self.n: + node = Node(value, key) + self.append(node) + self.cache[key] = node + self.n -= 1 - self.insert_as_head(node) + # new key, no free space + else: + node = Node(value, key) + self.pop_left() + self.append(node) + self.cache[key] = node