Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 62 additions & 25 deletions MercurySQL/drivers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ class BaseDriver:
"""
The base class of all drivers.
"""
dependencies = [] # just for hints about what dependencies are needed
version = '0.0.0'
payload = '?'

dependencies = [] # just for hints about what dependencies are needed
version = "0.0.0"
payload = "?"

class Cursor:
"""
Expand Down Expand Up @@ -96,11 +97,13 @@ class APIs:

Every definition will be followed by an example of it's return value (in SQLite).
"""

class gensql:
"""
APIs in this class will return a SQL code for database operations.
Generally, these returned codes will be executed **DIRECTLY**.
"""

@staticmethod
def drop_table(table_name: str) -> str:
"""
Expand Down Expand Up @@ -155,7 +158,13 @@ def get_all_columns(table_name: str) -> str:
# return f"PRAGMA table_info({table_name});"

@staticmethod
def create_table_if_not_exists(table_name: str, column_name: str, column_type: str, primaryKey=False, autoIncrement=False) -> Union[str, List[str]]:
def create_table_if_not_exists(
table_name: str,
column_name: str,
column_type: str,
primaryKey=False,
autoIncrement=False,
) -> Union[str, List[str]]:
"""
Create a table if it does not exist.

Expand All @@ -182,7 +191,9 @@ def create_table_if_not_exists(table_name: str, column_name: str, column_type: s
# """

@staticmethod
def add_column(table_name: str, column_name: str, column_type: str) -> Union[str, List[str]]:
def add_column(
table_name: str, column_name: str, column_type: str
) -> Union[str, List[str]]:
"""
Add a column to a table.

Expand Down Expand Up @@ -232,7 +243,9 @@ def drop_column(table_name: str, column_name: str) -> Union[str, List[str]]:
# """

@staticmethod
def set_primary_key(table, keyname: str, keytype: str) -> Union[str, List[str]]:
def set_primary_key(
table, keyname: str, keytype: str
) -> Union[str, List[str]]:
"""
Set a primary key for the specified table.

Expand Down Expand Up @@ -371,7 +384,7 @@ def delete(table_name: str, condition: str) -> str:
# return f"DELETE FROM {table_name} WHERE {condition}"

@classmethod
def get_all_tables(cls, conn: BaseDriver.Conn) -> List[str]:
def get_all_tables(cls, db) -> List[str]:
"""
Get all table's informations in the database.

Expand All @@ -380,17 +393,18 @@ def get_all_tables(cls, conn: BaseDriver.Conn) -> List[str]:

The default implementation is based on the `cls.gensql.get_all_tables()` method.

:param conn: The connection object of the database.
:type conn: BaseDriver.Conn
:param db: The DataBase object
:type db: MercurySQL.DataBase

:return: All table's informations in the database
"""
cursor = conn.cursor()
cursor.execute(cls.gensql.get_all_tables())
return list(map(lambda x: x[0], cursor.fetchall()))
cursor = db.do(cls.gensql.get_all_tables())
res = list(map(lambda x: x[0], cursor.fetchall()))
res = list(map(cls.reformat_table_name, res))
return res

@classmethod
def get_all_columns(cls, conn: BaseDriver.Conn, table_name: str) -> List[str]:
def get_all_columns(cls, db, table_name: str) -> List[str]:
"""
Get all column's informations in the table.

Expand All @@ -399,17 +413,40 @@ def get_all_columns(cls, conn: BaseDriver.Conn, table_name: str) -> List[str]:

The default implementation is based on the `cls.gensql.get_all_columns(table_name)` method.

:param conn: The connection object of the database.
:type conn: BaseDriver.Conn
:param db: The DataBase object
:type db: MercurySQL.DataBase

:param table_name: The name of the table.
:type table_name: str

:return: All column's informations in the table.
:rtype: List[str]. Each element is a list of `[column_name, column_type]`.
"""
cursor = conn.cursor()
cursor.execute(cls.gensql.get_all_columns(table_name))
return cursor.fetchall()
cursor = db.do(cls.gensql.get_all_columns(table_name))
res = list(map(lambda x: [x[1], x[2]], cursor.fetchall()))
res = list(map(lambda x: (cls.reformat_column_name(x[0]), x[1]), res))
return res

@staticmethod
def reformat_table_name(table_name: str) -> str:
"""
Reformat the table name to a valid format.

Default implementation is based on SQLite:
- No uppercase letters.
- Only lowercase letters, numbers, and underscores are allowed.

:param table_name: The table name to be reformatted.
:type table_name: str

:return: The reformatted table name.
:rtype: str
"""
res = table_name.lower()

for c in table_name:
if c not in "abcdefghijklmnopqrstuvwxyz0123456789_":
res = res.replace(c, "_")

class TypeParser:
"""
Expand Down Expand Up @@ -442,11 +479,11 @@ def parse(type_: Any) -> str:

"""
supported_types = {
str: 'TEXT',
int: 'INTEGER',
float: 'REAL',
bool: 'BOOLEAN',
bytes: 'BLOB'
str: "TEXT",
int: "INTEGER",
float: "REAL",
bool: "BOOLEAN",
bytes: "BLOB",
}

# round 1: Built-in Types
Expand All @@ -464,7 +501,7 @@ def parse(type_: Any) -> str:
"""

# round 3: Custom Types
if isinstance(type_, str): # custom type
if isinstance(type_, str): # custom type
return type_

# Not Supported
Expand All @@ -478,7 +515,7 @@ def connect(db_name: str, **kwargs) -> BaseDriver.Conn:
:param db_name: The name of the database to connect.
:type db_name: str
:param kwargs: The parameters of the connection. E.g., `host`, `port`, `user`, `password`, ...

:return: The connection object of the database.
:rtype: BaseDriver.Conn
"""
Expand Down
83 changes: 69 additions & 14 deletions MercurySQL/drivers/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Requirements:
- sqlite3
"""

from .base import BaseDriver

import sqlite3
Expand All @@ -13,9 +14,9 @@ class Driver_SQLite(BaseDriver):


class Driver_SQLite(BaseDriver):
dependencies = ['sqlite3']
version = '0.1.0'
payload = '?'
dependencies = ["sqlite3"]
version = "0.1.1"
payload = "?"

Conn = sqlite3.Connection
Cursor = sqlite3.Cursor
Expand All @@ -35,7 +36,13 @@ def get_all_columns(table_name: str) -> str:
return f"PRAGMA table_info({table_name});"

@staticmethod
def create_table_if_not_exists(table_name: str, column_name: str, column_type: str, primaryKey=False, autoIncrement=False) -> str:
def create_table_if_not_exists(
table_name: str,
column_name: str,
column_type: str,
primaryKey=False,
autoIncrement=False,
) -> str:
return f"""
CREATE TABLE IF NOT EXISTS {table_name} ({column_name} {column_type} {'PRIMARY KEY' if primaryKey else ''} {'AUTOINCREMENT' if autoIncrement else ''})
"""
Expand All @@ -58,7 +65,7 @@ def set_primary_key(table, keyname: str, keytype: str) -> list:
f"CREATE TABLE ___temp_table ({keyname} {keytype} PRIMARY KEY, {', '.join([f'{name} {type_}' for name, type_ in table.columnsType.items() if name != keyname])})",
f"INSERT INTO ___temp_table SELECT * FROM {table.table_name}",
f"DROP TABLE {table.table_name}",
f"ALTER TABLE ___temp_table RENAME TO {table.table_name}"
f"ALTER TABLE ___temp_table RENAME TO {table.table_name}",
]

@staticmethod
Expand All @@ -67,7 +74,9 @@ def insert(table_name: str, columns: str, values: str) -> str:

@staticmethod
def insert_or_update(table_name: str, columns: str, values: str) -> str:
return f"INSERT OR REPLACE INTO {table_name} ({columns}) VALUES ({values})"
return (
f"INSERT OR REPLACE INTO {table_name} ({columns}) VALUES ({values})"
)

@staticmethod
def update(table_name: str, columns: str, condition: str) -> str:
Expand All @@ -84,12 +93,58 @@ def delete(table_name: str, condition: str) -> str:
@classmethod
def get_all_tables(cls, db) -> List[str]:
cursor = db.do(cls.gensql.get_all_tables())
return list(map(lambda x: x[0], cursor.fetchall()))
res = list(map(lambda x: x[0], cursor.fetchall()))
res = list(map(cls.reformat_table_name, res))
return res

@classmethod
def get_all_columns(cls, db, table_name: str) -> List[str]:
cursor = db.do(cls.gensql.get_all_columns(table_name))
return list(map(lambda x: [x[1], x[2]], cursor.fetchall()))
res = list(map(lambda x: [x[1], x[2]], cursor.fetchall()))
res = list(map(lambda x: (cls.reformat_column_name(x[0]), x[1]), res))
return res

@staticmethod
def reformat_table_name(table_name: str) -> str:
"""
Rules:
- Lowercase
- Replace all non-alphanumeric characters with '_'
- Not starts with digits
"""
# Rule 1: Lowercase
res = table_name.lower()

# Rule 2: Replace all non-alphanumeric characters with '_'
for c in table_name:
if c not in "abcdefghijklmnopqrstuvwxyz0123456789_":
res = res.replace(c, '_')

# Rule 3: Not starts with digits
if c[0] in "0123456789":
res = "_" + res

return res

@staticmethod
def reformat_column_name(column_name: str) -> str:
"""
Rules:
- Lowercase
- Replace all non-alphanumeric characters with '_'
- Not starts with digits
"""
res = column_name.lower()

for c in column_name:
if c not in "abcdefghijklmnopqrstuvwxyz0123456789_":
res = res.replace(c, '_')

if res[0] in "0123456789":
res = "_" + res

return res


class TypeParser:
"""
Expand Down Expand Up @@ -132,11 +187,11 @@ def parse(type_: Any) -> str:

"""
supported_types = {
str: 'TEXT',
int: 'INTEGER',
float: 'REAL',
bool: 'BOOLEAN',
bytes: 'BLOB'
str: "TEXT",
int: "INTEGER",
float: "REAL",
bool: "BOOLEAN",
bytes: "BLOB",
}

# round 1: Built-in Types
Expand All @@ -154,7 +209,7 @@ def parse(type_: Any) -> str:
"""

# round 3: Custom Types
if isinstance(type_, str): # custom type
if isinstance(type_, str): # custom type
return type_

# Not Supported
Expand Down
10 changes: 8 additions & 2 deletions MercurySQL/gensql/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ def _gather_info(self):
Gather all infomations of the database, including:
- all tables
"""
self.tables = self.driver.APIs.get_all_tables(self)
self.tables = {tname: Table(self, tname) for tname in self.tables}
tables = self.driver.APIs.get_all_tables(self)
self.tables = {tname: Table(self, tname) for tname in tables}

def do(self, *sql: str, paras: List[tuple] = []):
"""
Expand Down Expand Up @@ -250,6 +250,8 @@ def createTable(
already_exists = False

for table_name in table_names:
table_name = self.driver.APIs.reformat_table_name(table_name)

if table_name in self.tables:
already_exists = True
if not force:
Expand Down Expand Up @@ -291,6 +293,8 @@ def __getitem__(self, key: str) -> Table:

.. note:: The only difference between `__getitem__()` and `createTable()` is that `__getitem__()` will return the **OLD** `Table` Object if exists, while `createTable()` will return a **NEW** `Table` Object.
"""
key = self.driver.APIs.reformat_table_name(key)

if key in self.tables:
return self.tables[key]
else:
Expand All @@ -315,6 +319,8 @@ def deleteTable(self, *table_names: str) -> None:
- raise an Exception if not exists.
"""
for table_name in table_names:
table_name = self.driver.APIs.reformat_table_name(table_name)

if table_name not in self.tables:
raise NotExistsError(f"Table `{table_name}` not exists.")

Expand Down
Loading