Skip to content

Commit

Permalink
Add cache logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
ralgond committed Dec 5, 2024
1 parent 5e4608b commit 6b7fc53
Show file tree
Hide file tree
Showing 6 changed files with 266 additions and 7 deletions.
3 changes: 2 additions & 1 deletion mybatis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .mapper_manager import MapperManager
from .mybatis import Mybatis
from .mybatis import Mybatis
from .cache import Cache, CacheKey
122 changes: 122 additions & 0 deletions mybatis/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import json
import pickle
from typing import Dict, Any, Optional

from pympler import asizeof

class CacheKey(object):
def __init__(self, sql, param_list):
self.sql = sql
self.param_list = param_list

def __hash__(self):
return hash((self.sql, self.param_list))

def __eq__(self, other):
if not isinstance(other, CacheKey):
return False
return self.sql == other.sql and self.param_list == other.param_list

class Cache(object):
def __init__(self, memory_limit:int):
self.memory_limit = memory_limit
self.memory_used = 0
self.table : Dict[str, CacheNode] = {}
self.list = CacheList()

def empty(self):
# assert self.list.head.next is self.list.tail
return len(self.table) == 0

def clear(self):
self.table.clear()
head = self.list.head
tail = self.list.tail

head.next = tail
tail.prev = head

def put(self, raw_key:CacheKey, value: Any):
key = json.dumps(raw_key.__dict__)
if key in self.table:
self.memory_used -= self.table[key].memory_usage
node = self.table[key]
else:
node = CacheNode(key, json.dumps(value))
node.memory_usage = asizeof.asizeof(node.key) + asizeof.asizeof(node.value)

# print("====>", node.memory_usage)

while self.memory_used + node.memory_usage >= self.memory_limit:
to_remove_node = self.list.tail.prev
if to_remove_node is not self.list.head:
del self.table[to_remove_node.key]
self.list.remove(to_remove_node)
self.memory_used -= to_remove_node.memory_usage
else:
break

if self.memory_used + node.memory_usage > self.memory_limit:
return

self.table[key] = node

self.list.move_to_head(node)

self.memory_used += node.memory_usage

def get(self, raw_key: CacheKey) -> Optional[Any]:
key = json.dumps(raw_key.__dict__)
if key not in self.table:
return None
node = self.table[key]
self.list.move_to_head(node)
return json.loads(node.value)

# def dump(self):
# for node in self.list.traverse():
# print(node.key, node.value, node.memory_usage, asizeof.asizeof(node.key) + asizeof.asizeof(node.value))
def traverse(self):
node = self.list.head.next
while node is not self.list.tail:
yield node.key, json.loads(node.value), node.memory_usage
node = node.next

class CacheNode:
def __init__(self, key : Any, value : Any):
self.prev = None
self.next = None
self.memory_usage = asizeof.asizeof(key) + asizeof.asizeof(value)
self.key = key
self.value = value

class CacheList:
def __init__(self):
self.head = CacheNode(None, None)
self.tail = CacheNode(None, None)
self.head.next = self.tail
self.tail.prev = self.head

@staticmethod
def remove(node:CacheNode):
if node.prev is None:
return

prev = node.prev
next = node.next

prev.next = next
next.prev = prev

node.prev = None
node.next = None

def insert_after_head(self, node:CacheNode):
node.next = self.head.next
self.head.next = node
node.next.prev = node
node.prev = self.head

def move_to_head(self, node:CacheNode):
CacheList.remove(node)
self.insert_after_head(node)
34 changes: 33 additions & 1 deletion mybatis/mybatis.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,32 @@
from typing import Optional, Dict, List

from .mapper_manager import MapperManager
from .cache import Cache, CacheKey

import os

class Mybatis(object):
def __init__(self, conn, mapper_path:str):
def __init__(self, conn, mapper_path:str, cache_memory_limit:Optional[int]=None):
self.conn = conn
self.mapper_manager = MapperManager()

if cache_memory_limit is not None:
self.cache = Cache(cache_memory_limit)
else:
self.cache = Cache(0)

mapper_file_name_l = [name for name in os.listdir(mapper_path) if name.endswith(".xml")]
for file_name in mapper_file_name_l:
full_path = os.path.join(mapper_path, file_name)
self.mapper_manager.read_mapper_xml_file(full_path)

def select_one(self, id:str, params:dict) -> Optional[Dict]:
sql, param_list = self.mapper_manager.select(id, params)
if self.cache is not None:
res = self.cache.get(CacheKey(sql, param_list))
if res is not None:
return res

with self.conn.cursor(prepared=True) as cursor:
cursor.execute(sql, param_list)
ret = cursor.fetchone()
Expand All @@ -24,10 +36,18 @@ def select_one(self, id:str, params:dict) -> Optional[Dict]:
res = {}
for idx, item in enumerate(column_name):
res[item] = ret[idx]

if self.cache is not None:
self.cache.put(CacheKey(sql, param_list), res)
return res

def select_many(self, id:str, params:dict) -> Optional[List[Dict]]:
sql, param_list = self.mapper_manager.select(id, params)
if self.cache is not None:
res = self.cache.get(CacheKey(sql, param_list))
if res is not None:
return res

with self.conn.cursor(prepared=True) as cursor:
cursor.execute(sql, param_list)
ret = cursor.fetchall()
Expand All @@ -40,6 +60,9 @@ def select_many(self, id:str, params:dict) -> Optional[List[Dict]]:
for idx, item in enumerate(column_name):
d[item] = row[idx]
res_list.append(d)

if self.cache is not None:
self.cache.put(CacheKey(sql, param_list), res_list)
return res_list

def update(self, id:str, params:dict) -> int:
Expand All @@ -49,6 +72,9 @@ def update(self, id:str, params:dict) -> int:
:return: affected rows
'''
sql, param_list = self.mapper_manager.update(id, params)

res = self.cache.clear()

with self.conn.cursor(prepared=True) as cursor:
cursor.execute(sql, param_list)
affected_rows = cursor.rowcount
Expand All @@ -62,6 +88,9 @@ def delete(self, id:str, params:dict) -> int:
:return: affected rows
'''
sql, param_list = self.mapper_manager.delete(id, params)

res = self.cache.clear()

with self.conn.cursor(prepared=True) as cursor:
cursor.execute(sql, param_list)
affected_rows = cursor.rowcount
Expand All @@ -75,6 +104,9 @@ def insert(self, id:str, params:dict) -> int:
:return: last auto incremented row id
'''
sql, param_list = self.mapper_manager.insert(id, params)

res = self.cache.clear()

with self.conn.cursor(prepared=True) as cursor:
cursor.execute(sql, param_list)
self.conn.commit()
Expand Down
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='mybatis',
version='0.0.2',
version='0.0.3',
description='A python ORM like mybatis.',
long_description=open('README.md').read(),
long_description_content_type='text/markdown', # 如果你使用的是Markdown格式的README
Expand All @@ -11,11 +11,12 @@
url='https://github.com/ralgond/mybatis-py',
packages=find_packages(),
install_requires=[
'mysql-connector-python>=9.0.0'
'mysql-connector-python>=9.0.0',
'Pympler>=1.1'
],
classifiers=[
'Programming Language :: Python :: 3',
'License :: OSI Approved :: MIT License',
'License :: OSI Approved :: Apache License',
'Operating System :: OS Independent',
],
python_requires='>=3.8',
Expand Down
44 changes: 44 additions & 0 deletions test/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest
from mybatis import Cache, CacheKey

def test_basic():
cache = Cache(memory_limit=555) # 50MB
cache.put(CacheKey("a", [1, 'a', None]), [{"a1": 1}, {"a2": 2}])
cache.put(CacheKey("b", [2, 'b', None]), "2")
cache.put(CacheKey("c", [3, 'c', None]), "3")
cache.put(CacheKey("d", [4, 'd', None]), None)

assert cache.get(CacheKey('a', [1, 'a', None])) == None
assert cache.get(CacheKey('b', [2, 'b', None])) == '2'

l = []
for key, value, memory_usage in cache.traverse():
l.append((key, value, memory_usage, type(value)))

assert len(l) == 3
assert l[0][0] == '{"sql": "b", "param_list": [2, "b", null]}'
assert l[0][1] == '2'

assert l[1][0] == '{"sql": "d", "param_list": [4, "d", null]}'
assert l[1][1] == None

assert l[2][0] == '{"sql": "c", "param_list": [3, "c", null]}'
assert l[2][1] == '3'

cache.put(CacheKey("e", [5, 'e', None]), "5")


l = []
for key, value, memory_usage in cache.traverse():
l.append((key, value, memory_usage, type(value)))

assert len(l) == 3

assert l[0][0] == '{"sql": "e", "param_list": [5, "e", null]}'
assert l[0][1] == '5'

assert l[1][0] == '{"sql": "b", "param_list": [2, "b", null]}'
assert l[1][1] == '2'

assert l[2][0] == '{"sql": "d", "param_list": [4, "d", null]}'
assert l[2][1] == None
63 changes: 61 additions & 2 deletions test/test_mybatis.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def db_connection():
connection.close()

def test_select_one(db_connection):
mb = Mybatis(db_connection, "mapper")
mb = Mybatis(db_connection, "mapper", cache_memory_limit=50*1024*1024)
ret = mb.select_one('testBasic', {})
assert ret is not None
assert len(ret) == 4
Expand All @@ -43,13 +43,36 @@ def test_select_one(db_connection):
assert ret['category'] == 'A'
assert ret['price'] == 100

ret = mb.select_one('testBasic', {})
assert ret is not None
assert len(ret) == 4
assert ret['id'] == 1
assert ret['name'] == 'Alice'
assert ret['category'] == 'A'
assert ret['price'] == 100



def test_select_one_none(db_connection):
mb = Mybatis(db_connection, "mapper")
ret = mb.select_one('testBasicNone', {})
assert ret is None

def test_select_many(db_connection):
mb = Mybatis(db_connection, "mapper")
mb = Mybatis(db_connection, "mapper", cache_memory_limit=50*1024*1024)
ret = mb.select_many('testBasicMany', {})
assert ret is not None
assert isinstance(ret, list)
assert len(ret) == 2
assert ret[0]['id'] == 1
assert ret[0]['name'] == 'Alice'
assert ret[0]['category'] == 'A'
assert ret[0]['price'] == 100
assert ret[1]['id'] == 2
assert ret[1]['name'] == 'Bob'
assert ret[1]['category'] == 'B'
assert ret[1]['price'] == 200

ret = mb.select_many('testBasicMany', {})
assert ret is not None
assert isinstance(ret, list)
Expand All @@ -70,7 +93,14 @@ def test_select_many_none(db_connection):

def test_update(db_connection):
mb = Mybatis(db_connection, "mapper")
mb.select_one('testBasic', {})

assert mb.cache.empty() is True

ret = mb.update("testUpdate", {"name":"Candy", "id":2})

assert mb.cache.empty() is True

assert ret == 1
ret = mb.select_many('testBasicMany', {})
assert ret is not None
Expand All @@ -85,6 +115,35 @@ def test_update(db_connection):
assert ret[1]['category'] == 'B'
assert ret[1]['price'] == 200

assert mb.cache.empty() is True


def test_update_with_cache(db_connection):
mb = Mybatis(db_connection, "mapper", cache_memory_limit=50*1024*1024)
mb.select_one('testBasic', {})

assert mb.cache.empty() is False

ret = mb.update("testUpdate", {"name":"Candy", "id":2})

assert mb.cache.empty() is True

assert ret == 1
ret = mb.select_many('testBasicMany', {})
assert ret is not None
assert isinstance(ret, list)
assert len(ret) == 2
assert ret[0]['id'] == 1
assert ret[0]['name'] == 'Alice'
assert ret[0]['category'] == 'A'
assert ret[0]['price'] == 100
assert ret[1]['id'] == 2
assert ret[1]['name'] == 'Candy'
assert ret[1]['category'] == 'B'
assert ret[1]['price'] == 200

assert mb.cache.empty() is False

def test_delete(db_connection):
mb = Mybatis(db_connection, "mapper")
ret = mb.delete("testDelete", {"id":2})
Expand Down

0 comments on commit 6b7fc53

Please sign in to comment.