Skip to content

Commit 322ccc4

Browse files
committed
Support sqlite3.
1 parent 50dd38e commit 322ccc4

File tree

8 files changed

+294
-21
lines changed

8 files changed

+294
-21
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,4 @@ cython_debug/
163163

164164
play
165165

166+
test.db

example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def main():
3232
# sql, param_list = mm.insert("testInsertSelective", {'name': 'Candy', 'category': "C", 'price':500})
3333
# print(sql, param_list)
3434

35-
sql, param_list = mm.insert("test_returning_id.insert", {'name': 'Candy', 'category': "C", 'price':500, '__need_returning_id__':'fid'})
35+
sql, param_list = mm.insert("test_returning_id.insert", {'name': 'Candy', 'category': "C", 'price':500}, primary_key='fid')
3636
print(sql, param_list)
3737

3838
# cur.execute(sql, param_list, multi=True)

mybatis/connection.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
import sqlite3
23
import time
34
from abc import ABC, abstractmethod
45
from typing import Optional, Sequence
@@ -11,6 +12,10 @@
1112

1213
from .errors import DatabaseError
1314

15+
from sqlite3 import Connection as Sqlite3ConnectionRaw
16+
from sqlite3 import Cursor as Sqlite3CursorRaw
17+
18+
1419

1520
class AbstractCursor(ABC):
1621
@abstractmethod
@@ -286,6 +291,98 @@ def reconnect(self, attempts, delay):
286291
else:
287292
raise DatabaseError(str("Reconnecting failed."))
288293

294+
class Sqlite3Cursor(AbstractCursor):
295+
def __init__(self, cursor: Sqlite3CursorRaw, *args, **kwargs):
296+
self.cursor = cursor
297+
298+
def execute(self, query: str, param_list: Sequence = None):
299+
try:
300+
if param_list is None:
301+
return self.cursor.execute(query)
302+
else:
303+
return self.cursor.execute(query, param_list)
304+
except sqlite3.Error as err:
305+
raise DatabaseError(str(err))
306+
307+
def rowcount(self):
308+
return self.cursor.rowcount
309+
310+
def lastrowid(self):
311+
return self.cursor.lastrowid
312+
313+
def description(self):
314+
return self.cursor.description
315+
316+
def fetchone(self):
317+
try:
318+
return self.cursor.fetchone()
319+
except sqlite3.Error as err:
320+
raise DatabaseError(str(err))
321+
322+
def fetchall(self):
323+
try:
324+
return self.cursor.fetchall()
325+
except sqlite3.Error as err:
326+
raise DatabaseError(str(err))
327+
328+
def fetchmany(self, size: int):
329+
try:
330+
return self.cursor.fetchmany(size)
331+
except sqlite3.Error as err:
332+
raise DatabaseError(str(err))
333+
334+
def close(self):
335+
try:
336+
self.cursor.close()
337+
except sqlite3.Error as err:
338+
raise DatabaseError(str(err))
339+
340+
def __enter__(self):
341+
return self
342+
343+
def __exit__(self, exc_type, exc_val, exc_tb):
344+
self.cursor.close()
345+
if exc_type:
346+
print(f"An exception occurred: {exc_val}")
347+
return False
348+
349+
class Sqlite3Connection(AbstractConnection):
350+
def __init__(self, conn: Sqlite3ConnectionRaw):
351+
self.conn = conn
352+
353+
def cursor(self, *args, **kwargs) -> AbstractCursor:
354+
prepared = False
355+
if 'prepared' in kwargs:
356+
prepared = kwargs['prepared']
357+
del kwargs['prepared']
358+
return Sqlite3Cursor(cursor=self.conn.cursor(*args, **kwargs))
359+
360+
def close(self):
361+
self.conn.close()
362+
363+
def set_autocommit(self, autocommit: bool):
364+
pass
365+
366+
def start_transaction(self):
367+
pass
368+
369+
def commit(self):
370+
try:
371+
self.conn.commit()
372+
except sqlite3.Error as err:
373+
raise DatabaseError(str(err))
374+
375+
def rollback(self):
376+
try:
377+
self.conn.rollback()
378+
except sqlite3.Error as err:
379+
raise DatabaseError(str(err))
380+
381+
def need_returning_id(self):
382+
return False
383+
384+
def reconnect(self, attempts, delay):
385+
pass
289386

290387
class ConnectionFactory(ABC):
291388
@staticmethod
@@ -304,5 +401,10 @@ def get_connection(*args, **kwargs) -> Optional[AbstractConnection]:
304401
ret_conn = PostgreSQLConnection(conn)
305402
ret_conn.__setattr__("connect_kwargs", kwargs)
306403
return ret_conn
404+
elif dbms_name == 'sqlite3':
405+
db_path = kwargs.get("db_path")
406+
conn = sqlite3.connect(db_path)
407+
ret_conn = Sqlite3Connection(conn)
408+
return ret_conn
307409
else:
308410
raise NotImplementedError(dbms_name)

mybatis/mapper_manager.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
import re
44

55
class MapperManager:
6-
def __init__(self, postgresql_primary_key_name=None):
6+
def __init__(self):
77
self.id_2_element_map = {}
88
self.param_pattern = re.compile(r"#{([a-zA-Z0-9_\-]+)}")
99
self.replace_pattern = re.compile(r"\${([a-zA-Z0-9_\-]+)}")
10-
self.postgresql_primary_key_name = postgresql_primary_key_name
1110

1211
def read_mapper_xml_file(self, mapper_xml_file_path):
1312
namespace = ""
@@ -340,7 +339,7 @@ def delete(self, id: str, params: dict) -> Tuple[str, list]:
340339
sql = self._to_replace(sql, params)
341340
return (sql, sql_param)
342341

343-
def insert(self, id: str, params: dict) -> Tuple[str, list]:
342+
def insert(self, id: str, params: dict, primary_key:str=None) -> Tuple[str, list]:
344343
if id not in self.id_2_element_map:
345344
raise Exception("Missing id")
346345
element = self.id_2_element_map[id]
@@ -358,7 +357,7 @@ def insert(self, id: str, params: dict) -> Tuple[str, list]:
358357
sql, sql_param = self._to_prepared_statement(ret, params)
359358
sql = self._to_replace(sql, params)
360359

361-
if self.postgresql_primary_key_name:
362-
sql += (" RETURNING "+str(self.postgresql_primary_key_name))
360+
if primary_key:
361+
sql += (" RETURNING "+str(primary_key))
363362

364363
return (sql, sql_param)

mybatis/mybatis.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,9 @@ def fetch_rows(cursor, batch_size=1000):
2828

2929
class Mybatis(object):
3030
def __init__(self, conn:AbstractConnection, mapper_path:str, cache_memory_limit:Optional[int]=None, cache_max_live_ms:int=5*1000,
31-
max_result_bytes:int=100*1024*1024, postgresql_primary_key_name=None):
31+
max_result_bytes:int=100*1024*1024):
3232
self.conn = conn
33-
self.postgresql_primary_key_name = postgresql_primary_key_name
34-
self.mapper_manager = MapperManager(postgresql_primary_key_name)
33+
self.mapper_manager = MapperManager()
3534
self.max_result_bytes = max_result_bytes
3635

3736
if cache_memory_limit is not None:
@@ -46,6 +45,7 @@ def __init__(self, conn:AbstractConnection, mapper_path:str, cache_memory_limit:
4645

4746
def select_one(self, id:str, params:dict) -> Optional[Dict]:
4847
sql, param_list = self.mapper_manager.select(id, params)
48+
4949
if self.cache.memory_limit > 0:
5050
res = self.cache.get(CacheKey(sql, param_list))
5151
if res is not None:
@@ -144,7 +144,9 @@ def insert(self, id:str, params:dict, primary_key:str=None) -> int:
144144
if self.conn.need_returning_id() and primary_key:
145145
params['__need_returning_id__'] = str(primary_key)
146146

147-
sql, param_list = self.mapper_manager.insert(id, params)
147+
sql, param_list = self.mapper_manager.insert(id, params, primary_key)
148+
149+
print("========>",sql,param_list)
148150

149151
res = self.cache.clear()
150152

test/test_mybatis_decorator_postgresql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def delete(id:int):
201201
assert ret[0]['price'] == 100
202202

203203
def test_insert(postgresql_db_connection):
204-
mb = Mybatis(postgresql_db_connection, "mapper", postgresql_primary_key_name="id")
204+
mb = Mybatis(postgresql_db_connection, "mapper")
205205

206206
@mb.SelectMany("SELECT * FROM fruits")
207207
def select_many():

test/test_mybatis_postgresql.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,6 @@
66

77
@pytest.fixture(scope="function")
88
def db_connection():
9-
# 配置数据库连接
10-
# connection = mysql.connector.connect(
11-
# host="localhost",
12-
# user="mybatis",
13-
# password="mybatis",
14-
# database="mybatis",
15-
# autocommit=False,
16-
# )
179
connection = ConnectionFactory.get_connection(
1810
dbms_name='postgresql',
1911
host="localhost",
@@ -167,8 +159,8 @@ def test_delete(db_connection):
167159
assert ret[0]['price'] == 100
168160

169161
def test_insert(db_connection):
170-
mb = Mybatis(db_connection, "mapper", postgresql_primary_key_name="id")
171-
ret = mb.insert("testInsert", {"name": "Candy", "category": "B", "price": 200})
162+
mb = Mybatis(db_connection, "mapper")
163+
ret = mb.insert("testInsert", {"name": "Candy", "category": "B", "price": 200}, primary_key="id")
172164
assert ret == 3
173165

174166
ret = mb.select_many('testBasicMany', {})

0 commit comments

Comments
 (0)