diff --git a/dfs_generate/server.py b/dfs_generate/server.py index 91f6c3f..36705fc 100644 --- a/dfs_generate/server.py +++ b/dfs_generate/server.py @@ -1,16 +1,15 @@ import os -from typing import Dict import bottle import isort from yapf.yapflib.yapf_api import FormatCode from dfs_generate.conversion import SQLModelConversion, TortoiseConversion -from dfs_generate.tools import MySQLConf, MySQLHelper +from dfs_generate.tools import Cache, MySQLConf, MySQLHelper app = bottle.Bottle() - -CACHE: Dict[str, MySQLConf] = {} +cache = Cache() +cache.start() # 解决打包桌面程序static找不到的问题 static_file_abspath = os.path.join( @@ -47,22 +46,30 @@ def index(): return html_content -@app.post("/conf") +# 连接数据库 +@app.get("/con") def connect(): + if cache.get(): + return {"code": 20000, "msg": "ok", "data": cache.get()} + return {"code": 40000, "msg": "error", "data": None} + + +@app.post("/conf") +def configure(): payload = bottle.request.json try: - with MySQLHelper(MySQLConf(**payload)) as obj: - CACHE["conf"] = MySQLConf(**payload) + with MySQLHelper(MySQLConf(**payload)): + cache.set(**MySQLConf(**payload).json()) return {"code": 20000, "msg": "ok", "data": None} except Exception as e: return {"code": 40000, "msg": str(e), "data": None} @app.get("/tables") -def tables(): +def get_tables(): like = bottle.request.query.get("tableName") try: - with MySQLHelper(CACHE.get("conf")) as obj: + with MySQLHelper(MySQLConf(**cache.get())) as obj: data = [ {"tableName": table, "key": table} for table in obj.get_tables() @@ -83,9 +90,9 @@ def codegen(): else: _instance = TortoiseConversion try: - with MySQLHelper(CACHE.get("conf")) as obj: + with MySQLHelper(MySQLConf(**cache.get())) as obj: data = _instance( - table, obj.get_table_columns(table), obj.conf.get_db_uri() + table, obj.get_table_columns(table), obj.conf.db_uri ).gencode() except Exception as e: return {"code": 40000, "msg": str(e), "data": None} diff --git a/dfs_generate/tools.py b/dfs_generate/tools.py index d30b537..124bfd1 100644 --- a/dfs_generate/tools.py +++ b/dfs_generate/tools.py @@ -1,5 +1,7 @@ +import os import re -from dataclasses import dataclass, asdict +import sqlite3 +from dataclasses import asdict, dataclass import pymysql @@ -74,7 +76,8 @@ class MySQLConf: port: int = 3306 charset: str = "utf8" - def get_db_uri(self): + @property + def db_uri(self): return f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}?charset={self.charset}" def json(self): @@ -102,13 +105,6 @@ def __init__(self, conf: MySQLConf): ) self.cursor = self.conn.cursor() - def set_conn(self, conf: MySQLConf): - self.conf = conf - self.conn = pymysql.connect( - **self.conf.json(), cursorclass=pymysql.cursors.DictCursor - ) - self.cursor = self.conn.cursor() - def close(self): self.cursor.close() self.conn.close() @@ -130,3 +126,71 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() + + +def get_cache_directory(): + """ + 获取适用于不同操作系统的缓存目录路径。 + """ + system = os.name + if system == "posix": # Linux, macOS, Unix + return os.path.expanduser("~/.cache") + elif system == "nt": # Windows + return os.path.expandvars(r"%LOCALAPPDATA%") + else: + return "." # 不支持的操作系统 + + +class Cache: + system = os.name + + def __init__(self): + if self.system == "posix": # Linux, macOS, Unix + cache_dir = os.path.expanduser("~/.cache") + elif self.system == "nt": # Windows + cache_dir = os.path.expandvars(r"%LOCALAPPDATA%") + else: + cache_dir = "." + app_cache = os.path.join(cache_dir, "dfs-generate") + if not os.path.isdir(app_cache): + os.mkdir(app_cache) + self.db_path = os.path.join(app_cache, ".data.db") + + def start(self): + create_table_sql = """ + CREATE TABLE IF NOT EXISTS conf ( + id INTEGER PRIMARY KEY, + user TEXT NOT NULL, + password TEXT NOT NULL, + host TEXT NOT NULL, + port INT NOT NULL, + db TEXT NOT NULL, + charset TEXT NOT NULL + ); + """ + with sqlite3.connect(self.db_path, check_same_thread=False) as conn: + cursor = conn.cursor() + cursor.execute(create_table_sql) + conn.commit() + + def set(self, user, password, host, port, db, charset): + with sqlite3.connect(self.db_path, check_same_thread=False) as conn: + cursor = conn.cursor() + insert_sql = """ + INSERT INTO conf (user, password, host, port, db, charset) VALUES (?, ?, ?, ?, ?, ?) + """ + cursor.execute(insert_sql, (user, password, host, port, db, charset)) + conn.commit() + + def get(self): + with sqlite3.connect(self.db_path, check_same_thread=False) as conn: + cursor = conn.cursor() + query_sql = """ + SELECT user, password, host, port, db, charset FROM conf ORDER BY id DESC LIMIT 1 + """ + cursor.execute(query_sql) + result = cursor.fetchone() + if result: + keys = ["user", "password", "host", "port", "db", "charset"] + return dict(zip(keys, result)) + return None diff --git a/tests/test_tools.py b/tests/test_tools.py index 045bfc5..e5ed129 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -49,7 +49,7 @@ def test_mysqlconf_get_db_uri(): host="localhost", user="test_user", password="secure_pwd", db="test_db" ) assert ( - conf.get_db_uri() + conf.db_uri == "mysql+pymysql://test_user:secure_pwd@localhost:3306/test_db?charset=utf8" ) diff --git a/web/src/compoents/content.jsx b/web/src/compoents/content.jsx index 14b6046..b187a3a 100644 --- a/web/src/compoents/content.jsx +++ b/web/src/compoents/content.jsx @@ -7,39 +7,30 @@ import { Input, InputNumber, message, - Affix, -} from "antd"; -import { useEffect, useState } from "react"; + Affix +} from 'antd' +import { useEffect, useState } from 'react' import { CodepenOutlined, SettingOutlined, - RedditOutlined, -} from "@ant-design/icons"; -import CodeGenerate from "../compoents/codegen"; -import { host } from "../conf"; + RedditOutlined +} from '@ant-design/icons' +import CodeGenerate from '../compoents/codegen' +import { host } from '../conf' -const changDBFormRules = [{ required: true, message: "该项必须填写" }]; +const changDBFormRules = [{ required: true, message: '该项必须填写' }] // 修改配置组件 -const ChangeDB = ({ onDbFinsh }) => { +const ChangeDB = ({ onDbFinsh, initialValues }) => { return (