Skip to content

Commit d035ac4

Browse files
committed
Support PostgreSQL.
1 parent e4c406a commit d035ac4

12 files changed

+950
-59
lines changed

.github/workflows/ci.yml

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ jobs:
1515
runs-on: ubuntu-latest
1616

1717
services:
18+
postgres:
19+
image: postgres:16
20+
env:
21+
POSTGRES_USER: mybatis
22+
POSTGRES_PASSWORD: mybatis
23+
POSTGRES_DB: mybatis
24+
ports:
25+
- 5432:5432
1826
mysql:
1927
image: mysql:5.7
2028
env:
@@ -36,7 +44,15 @@ jobs:
3644

3745
- name: Install dependencies
3846
run: |
39-
pip install pytest pytest-cov mysql-connector-python Pympler orjson
47+
pip install pytest pytest-cov mysql-connector-python Pympler orjson psycopg2-binary
48+
49+
- name: Wait for PostgreSQL to be ready
50+
run: |
51+
until pg_isready -h localhost -p 5432 -U mybatis; do
52+
echo "Waiting for PostgreSQL..."
53+
sleep 2
54+
done
55+
echo 'PostgreSQL is ready!'
4056
4157
- name: Set up MySQL client
4258
run: sudo apt-get install mysql-client

example2.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
from mybatis import *
2-
import mysql.connector
32

43
def main():
54
# 连接到 MySQL 数据库
6-
conn = mysql.connector.connect(
7-
host="localhost", # MySQL 主机地址
8-
user="mybatis", # MySQL 用户名
9-
password="mybatis", # MySQL 密码
10-
database="mybatis" # 需要连接的数据库
11-
)
5+
# conn = mysql.connector.connect(
6+
# host="localhost", # MySQL 主机地址
7+
# user="mybatis", # MySQL 用户名
8+
# password="mybatis", # MySQL 密码
9+
# database="mybatis" # 需要连接的数据库
10+
# )
11+
conn = ConnectionFactory.get_connection(dbms_name="postgresql",
12+
host="localhost",
13+
user="mybatis",
14+
password="mybatis",
15+
database="mybatis")
1216

1317
mb = Mybatis(conn, "mapper", cache_memory_limit=50*1024*1024)
1418

mybatis/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .mapper_manager import MapperManager
22
from .mybatis import Mybatis
3-
from .cache import Cache, CacheKey
3+
from .cache import Cache, CacheKey
4+
from .connection import AbstractConnection, AbstractCursor, MySQLConnection, MySQLCursor, ConnectionFactory

mybatis/connection.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
import re
2+
from abc import ABC, abstractmethod
3+
from typing import Optional, Sequence
4+
5+
import psycopg2
6+
from mysql.connector.abstracts import MySQLConnectionAbstract, MySQLCursorAbstract
7+
import mysql.connector
8+
from psycopg2.extensions import connection as PostgreSQLConnectionRaw
9+
from psycopg2.extensions import cursor as PostgreSQLCursorRaw
10+
11+
class AbstractCursor(ABC):
12+
@abstractmethod
13+
def execute(self, query: str, param_list : Sequence = None):
14+
pass
15+
16+
@abstractmethod
17+
def rowcount(self):
18+
pass
19+
20+
@abstractmethod
21+
def lastrowid(self):
22+
pass
23+
24+
@abstractmethod
25+
def description(self):
26+
pass
27+
28+
@abstractmethod
29+
def fetchone(self):
30+
pass
31+
32+
@abstractmethod
33+
def fetchall(self):
34+
pass
35+
36+
@abstractmethod
37+
def fetchmany(self, size: int):
38+
pass
39+
40+
@abstractmethod
41+
def close(self):
42+
pass
43+
44+
class AbstractConnection(ABC):
45+
@abstractmethod
46+
def cursor(self, *args, **kwargs) -> AbstractCursor:
47+
pass
48+
49+
@abstractmethod
50+
def close(self):
51+
pass
52+
53+
@abstractmethod
54+
def set_autocommit(self, autocommit: bool):
55+
pass
56+
57+
@abstractmethod
58+
def start_transaction(self):
59+
pass
60+
61+
@abstractmethod
62+
def commit(self):
63+
pass
64+
65+
@abstractmethod
66+
def rollback(self):
67+
pass
68+
69+
70+
class MySQLCursor(AbstractCursor):
71+
def __init__(self, cursor: MySQLCursorAbstract, *args, **kwargs):
72+
self.cursor = cursor
73+
74+
def execute(self, query: str, param_list : Sequence = None):
75+
return self.cursor.execute(query, param_list)
76+
77+
def rowcount(self):
78+
return self.cursor.rowcount
79+
80+
def lastrowid(self):
81+
return self.cursor.lastrowid
82+
83+
def description(self):
84+
return self.cursor.description
85+
86+
def fetchone(self):
87+
return self.cursor.fetchone()
88+
89+
def fetchall(self):
90+
return self.cursor.fetchall()
91+
92+
def fetchmany(self, size: int):
93+
return self.cursor.fetchmany(size)
94+
95+
def close(self):
96+
self.cursor.close()
97+
98+
def __enter__(self):
99+
return self
100+
101+
def __exit__(self, exc_type, exc_val, exc_tb):
102+
self.cursor.close()
103+
if exc_type:
104+
print(f"An exception occurred: {exc_val}")
105+
return False
106+
107+
108+
class MySQLConnection(AbstractConnection):
109+
def __init__(self, conn: MySQLConnectionAbstract):
110+
self.conn = conn
111+
112+
def cursor(self, *args, **kwargs) -> AbstractCursor:
113+
return MySQLCursor(cursor=self.conn.cursor(*args, **kwargs))
114+
115+
def close(self):
116+
self.conn.close()
117+
118+
def set_autocommit(self, autocommit: bool):
119+
self.conn.autocommit = autocommit
120+
121+
def start_transaction(self):
122+
self.set_autocommit(False)
123+
124+
def commit(self):
125+
self.conn.commit()
126+
127+
def rollback(self):
128+
self.conn.rollback()
129+
130+
class PostgreSQLCursor(AbstractCursor):
131+
def __init__(self, cursor: PostgreSQLCursorRaw, prepared=False):
132+
self.prepared = prepared
133+
self.cursor = cursor
134+
self.replace_pattern = re.compile(r"\?")
135+
136+
def execute(self, query: str, param_list:Sequence = None):
137+
query = self.replace_pattern.sub("%s", query)
138+
return self.cursor.execute(query, param_list)
139+
140+
def rowcount(self):
141+
return self.cursor.rowcount
142+
143+
def lastrowid(self):
144+
return self.cursor.fetchone()[0]
145+
146+
def description(self):
147+
return self.cursor.description
148+
149+
def fetchone(self):
150+
return self.cursor.fetchone()
151+
152+
def fetchall(self):
153+
return self.cursor.fetchall()
154+
155+
def fetchmany(self, size: int):
156+
return self.cursor.fetchmany(size)
157+
158+
def close(self):
159+
self.cursor.close()
160+
161+
def __enter__(self):
162+
return self
163+
164+
def __exit__(self, exc_type, exc_val, exc_tb):
165+
self.cursor.close()
166+
if exc_type:
167+
print(f"An exception occurred: {exc_val}")
168+
return False
169+
170+
171+
class PostgreSQLConnection(AbstractConnection):
172+
def __init__(self, conn: PostgreSQLConnectionRaw):
173+
self.conn = conn
174+
self.prepared = False
175+
176+
def cursor(self, *args, **kwargs) -> AbstractCursor:
177+
prepared = False
178+
if 'prepared' in kwargs:
179+
prepared = kwargs['prepared']
180+
del kwargs['prepared']
181+
return PostgreSQLCursor(cursor=self.conn.cursor(*args, **kwargs), prepared=prepared)
182+
183+
def close(self):
184+
self.conn.close()
185+
186+
def set_autocommit(self, autocommit: bool):
187+
self.conn.autocommit = autocommit
188+
189+
def start_transaction(self):
190+
pass
191+
192+
def commit(self):
193+
self.conn.commit()
194+
195+
def rollback(self):
196+
self.conn.rollback()
197+
198+
199+
class ConnectionFactory(ABC):
200+
@staticmethod
201+
def get_connection(*args, **kwargs) -> Optional[AbstractConnection]:
202+
dbms_name = kwargs.get("dbms_name")
203+
del kwargs['dbms_name']
204+
if dbms_name == 'mysql':
205+
conn = mysql.connector.connect(
206+
**kwargs
207+
)
208+
return MySQLConnection(conn)
209+
elif dbms_name == 'postgresql':
210+
conn = psycopg2.connect(
211+
**kwargs
212+
)
213+
return PostgreSQLConnection(conn)

mybatis/mapper_manager.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import re
44

55
class MapperManager:
6-
def __init__(self):
6+
def __init__(self, postgresql_primary_key_name=None):
77
self.id_2_element_map = {}
88
self.param_pattern = re.compile(r"#{([a-zA-Z0-9_\-]+)}")
99
self.replace_pattern = re.compile(r"\${([a-zA-Z0-9_\-]+)}")
10+
self.postgresql_primary_key_name = postgresql_primary_key_name
1011

1112
def read_mapper_xml_file(self, mapper_xml_file_path):
1213
namespace = ""
@@ -356,4 +357,8 @@ def insert(self, id: str, params: dict) -> Tuple[str, list]:
356357

357358
sql, sql_param = self._to_prepared_statement(ret, params)
358359
sql = self._to_replace(sql, params)
360+
361+
if self.postgresql_primary_key_name:
362+
sql += (" RETURNING "+str(self.postgresql_primary_key_name))
363+
359364
return (sql, sql_param)

0 commit comments

Comments
 (0)