From 6b7fc535f9b9c20a80246279556b1aca96dbad37 Mon Sep 17 00:00:00 2001 From: Teng Huang Date: Thu, 5 Dec 2024 11:29:00 +0800 Subject: [PATCH] Add cache logic. --- mybatis/__init__.py | 3 +- mybatis/cache.py | 122 +++++++++++++++++++++++++++++++++++++++++++ mybatis/mybatis.py | 34 +++++++++++- setup.py | 7 +-- test/test_cache.py | 44 ++++++++++++++++ test/test_mybatis.py | 63 +++++++++++++++++++++- 6 files changed, 266 insertions(+), 7 deletions(-) create mode 100644 mybatis/cache.py create mode 100644 test/test_cache.py diff --git a/mybatis/__init__.py b/mybatis/__init__.py index adf7171..3b0031d 100644 --- a/mybatis/__init__.py +++ b/mybatis/__init__.py @@ -1,2 +1,3 @@ from .mapper_manager import MapperManager -from .mybatis import Mybatis \ No newline at end of file +from .mybatis import Mybatis +from .cache import Cache, CacheKey \ No newline at end of file diff --git a/mybatis/cache.py b/mybatis/cache.py new file mode 100644 index 0000000..89d4a3b --- /dev/null +++ b/mybatis/cache.py @@ -0,0 +1,122 @@ +import json +import pickle +from typing import Dict, Any, Optional + +from pympler import asizeof + +class CacheKey(object): + def __init__(self, sql, param_list): + self.sql = sql + self.param_list = param_list + + def __hash__(self): + return hash((self.sql, self.param_list)) + + def __eq__(self, other): + if not isinstance(other, CacheKey): + return False + return self.sql == other.sql and self.param_list == other.param_list + +class Cache(object): + def __init__(self, memory_limit:int): + self.memory_limit = memory_limit + self.memory_used = 0 + self.table : Dict[str, CacheNode] = {} + self.list = CacheList() + + def empty(self): + # assert self.list.head.next is self.list.tail + return len(self.table) == 0 + + def clear(self): + self.table.clear() + head = self.list.head + tail = self.list.tail + + head.next = tail + tail.prev = head + + def put(self, raw_key:CacheKey, value: Any): + key = json.dumps(raw_key.__dict__) + if key in self.table: + self.memory_used -= self.table[key].memory_usage + node = self.table[key] + else: + node = CacheNode(key, json.dumps(value)) + node.memory_usage = asizeof.asizeof(node.key) + asizeof.asizeof(node.value) + + # print("====>", node.memory_usage) + + while self.memory_used + node.memory_usage >= self.memory_limit: + to_remove_node = self.list.tail.prev + if to_remove_node is not self.list.head: + del self.table[to_remove_node.key] + self.list.remove(to_remove_node) + self.memory_used -= to_remove_node.memory_usage + else: + break + + if self.memory_used + node.memory_usage > self.memory_limit: + return + + self.table[key] = node + + self.list.move_to_head(node) + + self.memory_used += node.memory_usage + + def get(self, raw_key: CacheKey) -> Optional[Any]: + key = json.dumps(raw_key.__dict__) + if key not in self.table: + return None + node = self.table[key] + self.list.move_to_head(node) + return json.loads(node.value) + + # def dump(self): + # for node in self.list.traverse(): + # print(node.key, node.value, node.memory_usage, asizeof.asizeof(node.key) + asizeof.asizeof(node.value)) + def traverse(self): + node = self.list.head.next + while node is not self.list.tail: + yield node.key, json.loads(node.value), node.memory_usage + node = node.next + +class CacheNode: + def __init__(self, key : Any, value : Any): + self.prev = None + self.next = None + self.memory_usage = asizeof.asizeof(key) + asizeof.asizeof(value) + self.key = key + self.value = value + +class CacheList: + def __init__(self): + self.head = CacheNode(None, None) + self.tail = CacheNode(None, None) + self.head.next = self.tail + self.tail.prev = self.head + + @staticmethod + def remove(node:CacheNode): + if node.prev is None: + return + + prev = node.prev + next = node.next + + prev.next = next + next.prev = prev + + node.prev = None + node.next = None + + def insert_after_head(self, node:CacheNode): + node.next = self.head.next + self.head.next = node + node.next.prev = node + node.prev = self.head + + def move_to_head(self, node:CacheNode): + CacheList.remove(node) + self.insert_after_head(node) diff --git a/mybatis/mybatis.py b/mybatis/mybatis.py index 46cfa20..7e17b0b 100644 --- a/mybatis/mybatis.py +++ b/mybatis/mybatis.py @@ -1,13 +1,20 @@ from typing import Optional, Dict, List from .mapper_manager import MapperManager +from .cache import Cache, CacheKey + import os class Mybatis(object): - def __init__(self, conn, mapper_path:str): + def __init__(self, conn, mapper_path:str, cache_memory_limit:Optional[int]=None): self.conn = conn self.mapper_manager = MapperManager() + if cache_memory_limit is not None: + self.cache = Cache(cache_memory_limit) + else: + self.cache = Cache(0) + mapper_file_name_l = [name for name in os.listdir(mapper_path) if name.endswith(".xml")] for file_name in mapper_file_name_l: full_path = os.path.join(mapper_path, file_name) @@ -15,6 +22,11 @@ def __init__(self, conn, mapper_path:str): def select_one(self, id:str, params:dict) -> Optional[Dict]: sql, param_list = self.mapper_manager.select(id, params) + if self.cache is not None: + res = self.cache.get(CacheKey(sql, param_list)) + if res is not None: + return res + with self.conn.cursor(prepared=True) as cursor: cursor.execute(sql, param_list) ret = cursor.fetchone() @@ -24,10 +36,18 @@ def select_one(self, id:str, params:dict) -> Optional[Dict]: res = {} for idx, item in enumerate(column_name): res[item] = ret[idx] + + if self.cache is not None: + self.cache.put(CacheKey(sql, param_list), res) return res def select_many(self, id:str, params:dict) -> Optional[List[Dict]]: sql, param_list = self.mapper_manager.select(id, params) + if self.cache is not None: + res = self.cache.get(CacheKey(sql, param_list)) + if res is not None: + return res + with self.conn.cursor(prepared=True) as cursor: cursor.execute(sql, param_list) ret = cursor.fetchall() @@ -40,6 +60,9 @@ def select_many(self, id:str, params:dict) -> Optional[List[Dict]]: for idx, item in enumerate(column_name): d[item] = row[idx] res_list.append(d) + + if self.cache is not None: + self.cache.put(CacheKey(sql, param_list), res_list) return res_list def update(self, id:str, params:dict) -> int: @@ -49,6 +72,9 @@ def update(self, id:str, params:dict) -> int: :return: affected rows ''' sql, param_list = self.mapper_manager.update(id, params) + + res = self.cache.clear() + with self.conn.cursor(prepared=True) as cursor: cursor.execute(sql, param_list) affected_rows = cursor.rowcount @@ -62,6 +88,9 @@ def delete(self, id:str, params:dict) -> int: :return: affected rows ''' sql, param_list = self.mapper_manager.delete(id, params) + + res = self.cache.clear() + with self.conn.cursor(prepared=True) as cursor: cursor.execute(sql, param_list) affected_rows = cursor.rowcount @@ -75,6 +104,9 @@ def insert(self, id:str, params:dict) -> int: :return: last auto incremented row id ''' sql, param_list = self.mapper_manager.insert(id, params) + + res = self.cache.clear() + with self.conn.cursor(prepared=True) as cursor: cursor.execute(sql, param_list) self.conn.commit() diff --git a/setup.py b/setup.py index ea82554..6c4dc8b 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='mybatis', - version='0.0.2', + version='0.0.3', description='A python ORM like mybatis.', long_description=open('README.md').read(), long_description_content_type='text/markdown', # 如果你使用的是Markdown格式的README @@ -11,11 +11,12 @@ url='https://github.com/ralgond/mybatis-py', packages=find_packages(), install_requires=[ - 'mysql-connector-python>=9.0.0' + 'mysql-connector-python>=9.0.0', + 'Pympler>=1.1' ], classifiers=[ 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: MIT License', + 'License :: OSI Approved :: Apache License', 'Operating System :: OS Independent', ], python_requires='>=3.8', diff --git a/test/test_cache.py b/test/test_cache.py new file mode 100644 index 0000000..a575e82 --- /dev/null +++ b/test/test_cache.py @@ -0,0 +1,44 @@ +import pytest +from mybatis import Cache, CacheKey + +def test_basic(): + cache = Cache(memory_limit=555) # 50MB + cache.put(CacheKey("a", [1, 'a', None]), [{"a1": 1}, {"a2": 2}]) + cache.put(CacheKey("b", [2, 'b', None]), "2") + cache.put(CacheKey("c", [3, 'c', None]), "3") + cache.put(CacheKey("d", [4, 'd', None]), None) + + assert cache.get(CacheKey('a', [1, 'a', None])) == None + assert cache.get(CacheKey('b', [2, 'b', None])) == '2' + + l = [] + for key, value, memory_usage in cache.traverse(): + l.append((key, value, memory_usage, type(value))) + + assert len(l) == 3 + assert l[0][0] == '{"sql": "b", "param_list": [2, "b", null]}' + assert l[0][1] == '2' + + assert l[1][0] == '{"sql": "d", "param_list": [4, "d", null]}' + assert l[1][1] == None + + assert l[2][0] == '{"sql": "c", "param_list": [3, "c", null]}' + assert l[2][1] == '3' + + cache.put(CacheKey("e", [5, 'e', None]), "5") + + + l = [] + for key, value, memory_usage in cache.traverse(): + l.append((key, value, memory_usage, type(value))) + + assert len(l) == 3 + + assert l[0][0] == '{"sql": "e", "param_list": [5, "e", null]}' + assert l[0][1] == '5' + + assert l[1][0] == '{"sql": "b", "param_list": [2, "b", null]}' + assert l[1][1] == '2' + + assert l[2][0] == '{"sql": "d", "param_list": [4, "d", null]}' + assert l[2][1] == None diff --git a/test/test_mybatis.py b/test/test_mybatis.py index 81d38b7..c1079b1 100644 --- a/test/test_mybatis.py +++ b/test/test_mybatis.py @@ -34,7 +34,7 @@ def db_connection(): connection.close() def test_select_one(db_connection): - mb = Mybatis(db_connection, "mapper") + mb = Mybatis(db_connection, "mapper", cache_memory_limit=50*1024*1024) ret = mb.select_one('testBasic', {}) assert ret is not None assert len(ret) == 4 @@ -43,13 +43,36 @@ def test_select_one(db_connection): assert ret['category'] == 'A' assert ret['price'] == 100 + ret = mb.select_one('testBasic', {}) + assert ret is not None + assert len(ret) == 4 + assert ret['id'] == 1 + assert ret['name'] == 'Alice' + assert ret['category'] == 'A' + assert ret['price'] == 100 + + + def test_select_one_none(db_connection): mb = Mybatis(db_connection, "mapper") ret = mb.select_one('testBasicNone', {}) assert ret is None def test_select_many(db_connection): - mb = Mybatis(db_connection, "mapper") + mb = Mybatis(db_connection, "mapper", cache_memory_limit=50*1024*1024) + ret = mb.select_many('testBasicMany', {}) + assert ret is not None + assert isinstance(ret, list) + assert len(ret) == 2 + assert ret[0]['id'] == 1 + assert ret[0]['name'] == 'Alice' + assert ret[0]['category'] == 'A' + assert ret[0]['price'] == 100 + assert ret[1]['id'] == 2 + assert ret[1]['name'] == 'Bob' + assert ret[1]['category'] == 'B' + assert ret[1]['price'] == 200 + ret = mb.select_many('testBasicMany', {}) assert ret is not None assert isinstance(ret, list) @@ -70,7 +93,14 @@ def test_select_many_none(db_connection): def test_update(db_connection): mb = Mybatis(db_connection, "mapper") + mb.select_one('testBasic', {}) + + assert mb.cache.empty() is True + ret = mb.update("testUpdate", {"name":"Candy", "id":2}) + + assert mb.cache.empty() is True + assert ret == 1 ret = mb.select_many('testBasicMany', {}) assert ret is not None @@ -85,6 +115,35 @@ def test_update(db_connection): assert ret[1]['category'] == 'B' assert ret[1]['price'] == 200 + assert mb.cache.empty() is True + + +def test_update_with_cache(db_connection): + mb = Mybatis(db_connection, "mapper", cache_memory_limit=50*1024*1024) + mb.select_one('testBasic', {}) + + assert mb.cache.empty() is False + + ret = mb.update("testUpdate", {"name":"Candy", "id":2}) + + assert mb.cache.empty() is True + + assert ret == 1 + ret = mb.select_many('testBasicMany', {}) + assert ret is not None + assert isinstance(ret, list) + assert len(ret) == 2 + assert ret[0]['id'] == 1 + assert ret[0]['name'] == 'Alice' + assert ret[0]['category'] == 'A' + assert ret[0]['price'] == 100 + assert ret[1]['id'] == 2 + assert ret[1]['name'] == 'Candy' + assert ret[1]['category'] == 'B' + assert ret[1]['price'] == 200 + + assert mb.cache.empty() is False + def test_delete(db_connection): mb = Mybatis(db_connection, "mapper") ret = mb.delete("testDelete", {"id":2})