From 1cdfce6ceec55aba8721b6c231bd5df6462b9116 Mon Sep 17 00:00:00 2001 From: berniehuang2008 Date: Mon, 22 Jul 2024 13:49:38 +0800 Subject: [PATCH 1/2] reformat table name --- MercurySQL/gensql/database.py | 10 ++++++++-- MercurySQL/gensql/table.py | 35 +++++++++++++++++++++++++++++------ 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/MercurySQL/gensql/database.py b/MercurySQL/gensql/database.py index c969e61..3b3ded4 100644 --- a/MercurySQL/gensql/database.py +++ b/MercurySQL/gensql/database.py @@ -128,8 +128,8 @@ def _gather_info(self): Gather all infomations of the database, including: - all tables """ - self.tables = self.driver.APIs.get_all_tables(self) - self.tables = {tname: Table(self, tname) for tname in self.tables} + tables = self.driver.APIs.get_all_tables(self) + self.tables = {tname: Table(self, tname) for tname in tables} def do(self, *sql: str, paras: List[tuple] = []): """ @@ -250,6 +250,8 @@ def createTable( already_exists = False for table_name in table_names: + table_name = self.driver.APIs.reformat_table_name(table_name) + if table_name in self.tables: already_exists = True if not force: @@ -291,6 +293,8 @@ def __getitem__(self, key: str) -> Table: .. note:: The only difference between `__getitem__()` and `createTable()` is that `__getitem__()` will return the **OLD** `Table` Object if exists, while `createTable()` will return a **NEW** `Table` Object. """ + key = self.driver.APIs.reformat_table_name(key) + if key in self.tables: return self.tables[key] else: @@ -315,6 +319,8 @@ def deleteTable(self, *table_names: str) -> None: - raise an Exception if not exists. """ for table_name in table_names: + table_name = self.driver.APIs.reformat_table_name(table_name) + if table_name not in self.tables: raise NotExistsError(f"Table `{table_name}` not exists.") diff --git a/MercurySQL/gensql/table.py b/MercurySQL/gensql/table.py index 3b5d1e8..1d9596b 100644 --- a/MercurySQL/gensql/table.py +++ b/MercurySQL/gensql/table.py @@ -51,7 +51,7 @@ def __init__(self, db, table_name: str): - set `isEmpty` to True if the table doesn't have any columns. This variable will effect how the `newColumn()` method works. """ self.db = db - self.table_name = table_name + self.table_name = self.driver.APIs.reformat_table_name(table_name) self.driver = db.driver all_tables = self.driver.APIs.get_all_tables(self.db) @@ -110,6 +110,8 @@ def __getitem__(self, key: str) -> Exp: - `_str` is used to print the definition of a column using `str(...)`. - raise an Exception if column not exists. """ + key = self.driver.APIs.reformat_column_name(key) + if key not in self.columns: raise NotExistsError(f"Column `{key}` not exists.") @@ -140,6 +142,8 @@ def __setitem__(self, key: str, value: Any) -> None: - get options from `value` if it has parameters (Judge it by whether it's a tuple, so you can use it as the L3 of example showed). - Actually create the column, using `newColumn()` method. """ + key = self.driver.APIs.reformat_column_name(key) + options = { "primary key": False, "auto increment": False, @@ -176,6 +180,8 @@ def __delitem__(self, key: str) -> None: del table['name'] # same as table.delColumn('name') """ + key = self.driver.APIs.reformat_column_name(key) + self.delColumn(key) def select(self, exp: Exp = None, selection: str = "*") -> QueryResult: @@ -237,6 +243,8 @@ def newColumn( - Set as the `primary key` column if `primaryKey` is True. - Record its name and type in `self.columns` and `self.columnsType`. """ + name = self.driver.APIs.reformat_column_name(name) + if name in self.columns: if not force: raise DuplicateError(f"Column `{name}` already exists.") @@ -267,7 +275,12 @@ def newColumn( self.columnsType[name] = type_ def struct( - self, columns: dict, skipError=True, primaryKey: str = None, autoIncrement=False, force=True + self, + columns: dict, + skipError=True, + primaryKey: str = None, + autoIncrement=False, + force=True, ) -> None: """ Set the structure of the table. @@ -295,6 +308,8 @@ def struct( skipError = skipError and force for name, type_ in columns.items(): + name = self.driver.APIs.reformat_column_name(name) + type_origin = type_ type_ = self.driver.TypeParser.parse(type_) isPrimaryKey = name == primaryKey @@ -317,6 +332,8 @@ def struct( ) def delColumn(self, name: str) -> None: + name = self.driver.APIs.reformat_column_name(name) + if name not in self.columns: # column not exist raise NotExistsError(f"Column `{name}` not exist!") @@ -346,6 +363,8 @@ def setPrimaryKey(self, keyname: str, keytype: str) -> None: table.setPrimaryKey('id') """ + keyname = self.driver.APIs.reformat_column_name(keyname) + cmd = self.driver.APIs.gensql.set_primary_key(self, keyname, keytype) self.db.do(cmd) @@ -370,6 +389,8 @@ def insert(self, __auto=False, **kwargs) -> None: if "__auto" in keys: __auto = kwargs["__auto"] keys.remove("__auto") + + keys = list(map(self.driver.APIs.reformat_column_name, keys)) columns = ", ".join(keys) values = ", ".join([self.driver.payload for _ in range(len(keys))]) @@ -384,7 +405,7 @@ def insert(self, __auto=False, **kwargs) -> None: self.db.do(cmd, paras=[tuple(kwargs[k] for k in keys)]) - def update(self, exp: Exp, data: dict={}, **kwargs) -> None: + def update(self, exp: Exp, data: dict = {}, **kwargs) -> None: """ Update the table. @@ -409,12 +430,14 @@ def update(self, exp: Exp, data: dict={}, **kwargs) -> None: table.update(table['id'] == 1, name='Bernie', age=15) # OR (table['id'] == 1).update({"name": "Bernie", "age": 15}) # recommended - + """ - + if not data: data = kwargs - + + data = {self.driver.APIs.reformat_column_name(k): v for k, v in data.items()} + columns = ", ".join([f"{key} = {self.driver.payload}" for key in data.keys()]) values = tuple(data.values()) From 45a6be96de4ce3707f394cdeeacfae3da6d0828e Mon Sep 17 00:00:00 2001 From: berniehuang2008 Date: Mon, 22 Jul 2024 13:50:02 +0800 Subject: [PATCH 2/2] modeify drivers to offer the reformat_names apis. --- MercurySQL/drivers/base.py | 87 +++++++++++++++++++++++++----------- MercurySQL/drivers/sqlite.py | 83 ++++++++++++++++++++++++++++------ 2 files changed, 131 insertions(+), 39 deletions(-) diff --git a/MercurySQL/drivers/base.py b/MercurySQL/drivers/base.py index 69ccfd7..c93817b 100644 --- a/MercurySQL/drivers/base.py +++ b/MercurySQL/drivers/base.py @@ -16,9 +16,10 @@ class BaseDriver: """ The base class of all drivers. """ - dependencies = [] # just for hints about what dependencies are needed - version = '0.0.0' - payload = '?' + + dependencies = [] # just for hints about what dependencies are needed + version = "0.0.0" + payload = "?" class Cursor: """ @@ -96,11 +97,13 @@ class APIs: Every definition will be followed by an example of it's return value (in SQLite). """ + class gensql: """ APIs in this class will return a SQL code for database operations. Generally, these returned codes will be executed **DIRECTLY**. """ + @staticmethod def drop_table(table_name: str) -> str: """ @@ -155,7 +158,13 @@ def get_all_columns(table_name: str) -> str: # return f"PRAGMA table_info({table_name});" @staticmethod - def create_table_if_not_exists(table_name: str, column_name: str, column_type: str, primaryKey=False, autoIncrement=False) -> Union[str, List[str]]: + def create_table_if_not_exists( + table_name: str, + column_name: str, + column_type: str, + primaryKey=False, + autoIncrement=False, + ) -> Union[str, List[str]]: """ Create a table if it does not exist. @@ -182,7 +191,9 @@ def create_table_if_not_exists(table_name: str, column_name: str, column_type: s # """ @staticmethod - def add_column(table_name: str, column_name: str, column_type: str) -> Union[str, List[str]]: + def add_column( + table_name: str, column_name: str, column_type: str + ) -> Union[str, List[str]]: """ Add a column to a table. @@ -232,7 +243,9 @@ def drop_column(table_name: str, column_name: str) -> Union[str, List[str]]: # """ @staticmethod - def set_primary_key(table, keyname: str, keytype: str) -> Union[str, List[str]]: + def set_primary_key( + table, keyname: str, keytype: str + ) -> Union[str, List[str]]: """ Set a primary key for the specified table. @@ -371,7 +384,7 @@ def delete(table_name: str, condition: str) -> str: # return f"DELETE FROM {table_name} WHERE {condition}" @classmethod - def get_all_tables(cls, conn: BaseDriver.Conn) -> List[str]: + def get_all_tables(cls, db) -> List[str]: """ Get all table's informations in the database. @@ -380,17 +393,18 @@ def get_all_tables(cls, conn: BaseDriver.Conn) -> List[str]: The default implementation is based on the `cls.gensql.get_all_tables()` method. - :param conn: The connection object of the database. - :type conn: BaseDriver.Conn + :param db: The DataBase object + :type db: MercurySQL.DataBase :return: All table's informations in the database """ - cursor = conn.cursor() - cursor.execute(cls.gensql.get_all_tables()) - return list(map(lambda x: x[0], cursor.fetchall())) + cursor = db.do(cls.gensql.get_all_tables()) + res = list(map(lambda x: x[0], cursor.fetchall())) + res = list(map(cls.reformat_table_name, res)) + return res @classmethod - def get_all_columns(cls, conn: BaseDriver.Conn, table_name: str) -> List[str]: + def get_all_columns(cls, db, table_name: str) -> List[str]: """ Get all column's informations in the table. @@ -399,17 +413,40 @@ def get_all_columns(cls, conn: BaseDriver.Conn, table_name: str) -> List[str]: The default implementation is based on the `cls.gensql.get_all_columns(table_name)` method. - :param conn: The connection object of the database. - :type conn: BaseDriver.Conn + :param db: The DataBase object + :type db: MercurySQL.DataBase + :param table_name: The name of the table. :type table_name: str :return: All column's informations in the table. :rtype: List[str]. Each element is a list of `[column_name, column_type]`. """ - cursor = conn.cursor() - cursor.execute(cls.gensql.get_all_columns(table_name)) - return cursor.fetchall() + cursor = db.do(cls.gensql.get_all_columns(table_name)) + res = list(map(lambda x: [x[1], x[2]], cursor.fetchall())) + res = list(map(lambda x: (cls.reformat_column_name(x[0]), x[1]), res)) + return res + + @staticmethod + def reformat_table_name(table_name: str) -> str: + """ + Reformat the table name to a valid format. + + Default implementation is based on SQLite: + - No uppercase letters. + - Only lowercase letters, numbers, and underscores are allowed. + + :param table_name: The table name to be reformatted. + :type table_name: str + + :return: The reformatted table name. + :rtype: str + """ + res = table_name.lower() + + for c in table_name: + if c not in "abcdefghijklmnopqrstuvwxyz0123456789_": + res = res.replace(c, "_") class TypeParser: """ @@ -442,11 +479,11 @@ def parse(type_: Any) -> str: """ supported_types = { - str: 'TEXT', - int: 'INTEGER', - float: 'REAL', - bool: 'BOOLEAN', - bytes: 'BLOB' + str: "TEXT", + int: "INTEGER", + float: "REAL", + bool: "BOOLEAN", + bytes: "BLOB", } # round 1: Built-in Types @@ -464,7 +501,7 @@ def parse(type_: Any) -> str: """ # round 3: Custom Types - if isinstance(type_, str): # custom type + if isinstance(type_, str): # custom type return type_ # Not Supported @@ -478,7 +515,7 @@ def connect(db_name: str, **kwargs) -> BaseDriver.Conn: :param db_name: The name of the database to connect. :type db_name: str :param kwargs: The parameters of the connection. E.g., `host`, `port`, `user`, `password`, ... - + :return: The connection object of the database. :rtype: BaseDriver.Conn """ diff --git a/MercurySQL/drivers/sqlite.py b/MercurySQL/drivers/sqlite.py index 548d812..e6bdea6 100644 --- a/MercurySQL/drivers/sqlite.py +++ b/MercurySQL/drivers/sqlite.py @@ -2,6 +2,7 @@ Requirements: - sqlite3 """ + from .base import BaseDriver import sqlite3 @@ -13,9 +14,9 @@ class Driver_SQLite(BaseDriver): class Driver_SQLite(BaseDriver): - dependencies = ['sqlite3'] - version = '0.1.0' - payload = '?' + dependencies = ["sqlite3"] + version = "0.1.1" + payload = "?" Conn = sqlite3.Connection Cursor = sqlite3.Cursor @@ -35,7 +36,13 @@ def get_all_columns(table_name: str) -> str: return f"PRAGMA table_info({table_name});" @staticmethod - def create_table_if_not_exists(table_name: str, column_name: str, column_type: str, primaryKey=False, autoIncrement=False) -> str: + def create_table_if_not_exists( + table_name: str, + column_name: str, + column_type: str, + primaryKey=False, + autoIncrement=False, + ) -> str: return f""" CREATE TABLE IF NOT EXISTS {table_name} ({column_name} {column_type} {'PRIMARY KEY' if primaryKey else ''} {'AUTOINCREMENT' if autoIncrement else ''}) """ @@ -58,7 +65,7 @@ def set_primary_key(table, keyname: str, keytype: str) -> list: f"CREATE TABLE ___temp_table ({keyname} {keytype} PRIMARY KEY, {', '.join([f'{name} {type_}' for name, type_ in table.columnsType.items() if name != keyname])})", f"INSERT INTO ___temp_table SELECT * FROM {table.table_name}", f"DROP TABLE {table.table_name}", - f"ALTER TABLE ___temp_table RENAME TO {table.table_name}" + f"ALTER TABLE ___temp_table RENAME TO {table.table_name}", ] @staticmethod @@ -67,7 +74,9 @@ def insert(table_name: str, columns: str, values: str) -> str: @staticmethod def insert_or_update(table_name: str, columns: str, values: str) -> str: - return f"INSERT OR REPLACE INTO {table_name} ({columns}) VALUES ({values})" + return ( + f"INSERT OR REPLACE INTO {table_name} ({columns}) VALUES ({values})" + ) @staticmethod def update(table_name: str, columns: str, condition: str) -> str: @@ -84,12 +93,58 @@ def delete(table_name: str, condition: str) -> str: @classmethod def get_all_tables(cls, db) -> List[str]: cursor = db.do(cls.gensql.get_all_tables()) - return list(map(lambda x: x[0], cursor.fetchall())) + res = list(map(lambda x: x[0], cursor.fetchall())) + res = list(map(cls.reformat_table_name, res)) + return res @classmethod def get_all_columns(cls, db, table_name: str) -> List[str]: cursor = db.do(cls.gensql.get_all_columns(table_name)) - return list(map(lambda x: [x[1], x[2]], cursor.fetchall())) + res = list(map(lambda x: [x[1], x[2]], cursor.fetchall())) + res = list(map(lambda x: (cls.reformat_column_name(x[0]), x[1]), res)) + return res + + @staticmethod + def reformat_table_name(table_name: str) -> str: + """ + Rules: + - Lowercase + - Replace all non-alphanumeric characters with '_' + - Not starts with digits + """ + # Rule 1: Lowercase + res = table_name.lower() + + # Rule 2: Replace all non-alphanumeric characters with '_' + for c in table_name: + if c not in "abcdefghijklmnopqrstuvwxyz0123456789_": + res = res.replace(c, '_') + + # Rule 3: Not starts with digits + if c[0] in "0123456789": + res = "_" + res + + return res + + @staticmethod + def reformat_column_name(column_name: str) -> str: + """ + Rules: + - Lowercase + - Replace all non-alphanumeric characters with '_' + - Not starts with digits + """ + res = column_name.lower() + + for c in column_name: + if c not in "abcdefghijklmnopqrstuvwxyz0123456789_": + res = res.replace(c, '_') + + if res[0] in "0123456789": + res = "_" + res + + return res + class TypeParser: """ @@ -132,11 +187,11 @@ def parse(type_: Any) -> str: """ supported_types = { - str: 'TEXT', - int: 'INTEGER', - float: 'REAL', - bool: 'BOOLEAN', - bytes: 'BLOB' + str: "TEXT", + int: "INTEGER", + float: "REAL", + bool: "BOOLEAN", + bytes: "BLOB", } # round 1: Built-in Types @@ -154,7 +209,7 @@ def parse(type_: Any) -> str: """ # round 3: Custom Types - if isinstance(type_, str): # custom type + if isinstance(type_, str): # custom type return type_ # Not Supported