Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
Upcoming (TBD)
==============

Internal
--------
* Create new data class to handle SQL/command results to make further code improvements easier


1.44.1 (2026/01/10)
==============

Expand Down
7 changes: 4 additions & 3 deletions mycli/completion_refresher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Callable

from mycli.packages.special.main import COMMANDS
from mycli.packages.sqlresult import SQLResult
from mycli.sqlcompleter import SQLCompleter
from mycli.sqlexecute import ServerSpecies, SQLExecute

Expand All @@ -18,7 +19,7 @@ def refresh(
executor: SQLExecute,
callbacks: Callable | list[Callable],
completer_options: dict | None = None,
) -> list[tuple]:
) -> list[SQLResult]:
"""Creates a SQLCompleter object and populates it with the relevant
completion suggestions in a background thread.

Expand All @@ -35,14 +36,14 @@ def refresh(

if self.is_refreshing():
self._restart_refresh.set()
return [(None, None, None, "Auto-completion refresh restarted.")]
return [SQLResult(status="Auto-completion refresh restarted.")]
else:
self._completer_thread = threading.Thread(
target=self._bg_refresh, args=(executor, callbacks, completer_options), name="completion_refresh"
)
self._completer_thread.daemon = True
self._completer_thread.start()
return [(None, None, None, "Auto-completion refresh started in the background.")]
return [SQLResult(status="Auto-completion refresh started in the background.")]

def is_refreshing(self) -> bool:
return bool(self._completer_thread and self._completer_thread.is_alive())
Expand Down
56 changes: 26 additions & 30 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from mycli.packages.prompt_utils import confirm, confirm_destructive_query
from mycli.packages.special.favoritequeries import FavoriteQueries
from mycli.packages.special.main import ArgType
from mycli.packages.sqlresult import SQLResult
from mycli.packages.tabular_output import sql_format
from mycli.packages.toolkit.history import FileHistoryWithTimestamp
from mycli.sqlcompleter import SQLCompleter
Expand Down Expand Up @@ -274,49 +275,49 @@ def register_special_commands(self) -> None:
self.change_prompt_format, "prompt", "\\R", "Change prompt format.", aliases=["\\R"], case_sensitive=True
)

def manual_reconnect(self, arg: str = "", **_) -> Generator[tuple, None, None]:
def manual_reconnect(self, arg: str = "", **_) -> Generator[SQLResult, None, None]:
"""
Interactive method to use for the \r command, so that the utility method
may be cleanly used elsewhere.
"""
if not self.reconnect(database=arg):
yield (None, None, None, "Not connected")
yield SQLResult(status="Not connected")
elif not arg or arg == '``':
yield (None, None, None, None)
yield SQLResult()
else:
yield self.change_db(arg).send(None)

def enable_show_warnings(self, **_) -> Generator[tuple, None, None]:
def enable_show_warnings(self, **_) -> Generator[SQLResult, None, None]:
self.show_warnings = True
msg = "Show warnings enabled."
yield (None, None, None, msg)
yield SQLResult(status=msg)

def disable_show_warnings(self, **_) -> Generator[tuple, None, None]:
def disable_show_warnings(self, **_) -> Generator[SQLResult, None, None]:
self.show_warnings = False
msg = "Show warnings disabled."
yield (None, None, None, msg)
yield SQLResult(status=msg)

def change_table_format(self, arg: str, **_) -> Generator[tuple, None, None]:
def change_table_format(self, arg: str, **_) -> Generator[SQLResult, None, None]:
try:
self.main_formatter.format_name = arg
yield (None, None, None, f"Changed table format to {arg}")
yield SQLResult(status=f"Changed table format to {arg}")
except ValueError:
msg = f"Table format {arg} not recognized. Allowed formats:"
for table_type in self.main_formatter.supported_formats:
msg += f"\n\t{table_type}"
yield (None, None, None, msg)
yield SQLResult(status=msg)

def change_redirect_format(self, arg: str, **_) -> Generator[tuple, None, None]:
def change_redirect_format(self, arg: str, **_) -> Generator[SQLResult, None, None]:
try:
self.redirect_formatter.format_name = arg
yield (None, None, None, f"Changed redirect format to {arg}")
yield SQLResult(status=f"Changed redirect format to {arg}")
except ValueError:
msg = f"Redirect format {arg} not recognized. Allowed formats:"
for table_type in self.redirect_formatter.supported_formats:
msg += f"\n\t{table_type}"
yield (None, None, None, msg)
yield SQLResult(status=msg)

def change_db(self, arg: str, **_) -> Generator[tuple, None, None]:
def change_db(self, arg: str, **_) -> Generator[SQLResult, None, None]:
if arg.startswith("`") and arg.endswith("`"):
arg = re.sub(r"^`(.*)`$", r"\1", arg)
arg = re.sub(r"``", r"`", arg)
Expand All @@ -333,40 +334,35 @@ def change_db(self, arg: str, **_) -> Generator[tuple, None, None]:
self.sqlexecute.change_db(arg)
msg = f'You are now connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"'

yield (
None,
None,
None,
msg,
)
yield SQLResult(status=msg)

def execute_from_file(self, arg: str, **_) -> Iterable[tuple]:
def execute_from_file(self, arg: str, **_) -> Iterable[SQLResult]:
if not arg:
message = "Missing required argument: filename."
return [(None, None, None, message)]
return [SQLResult(status=message)]
try:
with open(os.path.expanduser(arg)) as f:
query = f.read()
except IOError as e:
return [(None, None, None, str(e))]
return [SQLResult(status=str(e))]

if self.destructive_warning and confirm_destructive_query(query) is False:
message = "Wise choice. Command execution stopped."
return [(None, None, None, message)]
return [SQLResult(status=message)]

assert isinstance(self.sqlexecute, SQLExecute)
return self.sqlexecute.run(query)

def change_prompt_format(self, arg: str, **_) -> list[tuple]:
def change_prompt_format(self, arg: str, **_) -> list[SQLResult]:
"""
Change the prompt format.
"""
if not arg:
message = "Missing required argument, format."
return [(None, None, None, message)]
return [SQLResult(status=message)]

self.prompt_format = self.get_prompt(arg)
return [(None, None, None, f"Changed prompt format to {arg}")]
return [SQLResult(status=f"Changed prompt format to {arg}")]

def initialize_logging(self) -> None:
log_file = os.path.expanduser(self.config["main"]["log_file"])
Expand Down Expand Up @@ -820,7 +816,7 @@ def show_suggestion_tip() -> bool:
# mutating if any one of the component statements is mutating
mutating = False

def output_res(res: Generator[tuple], start: float) -> None:
def output_res(res: Generator[SQLResult], start: float) -> None:
nonlocal mutating
result_count = 0
for title, cur, headers, status in res:
Expand Down Expand Up @@ -1274,7 +1270,7 @@ def configure_pager(self) -> None:
if cnf["skip-pager"] or not self.config["main"].as_bool("enable_pager"):
special.disable_pager()

def refresh_completions(self, reset: bool = False) -> list[tuple]:
def refresh_completions(self, reset: bool = False) -> list[SQLResult]:
if reset:
with self._completer_lock:
self.completer.reset_completions()
Expand All @@ -1289,7 +1285,7 @@ def refresh_completions(self, reset: bool = False) -> list[tuple]:
},
)

return [(None, None, None, "Auto-completion refresh started in the background.")]
return [SQLResult(status="Auto-completion refresh started in the background.")]

def _on_completions_refreshed(self, new_completer: SQLCompleter) -> None:
"""Swap the completer object in cli with the newly created completer."""
Expand Down
18 changes: 9 additions & 9 deletions mycli/packages/special/dbcommands.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from mycli.packages.special import iocommands
from mycli.packages.special.main import ArgType, special_command
from mycli.packages.special.utils import format_uptime
from mycli.packages.sqlresult import SQLResult

logger = logging.getLogger(__name__)

Expand All @@ -19,19 +20,18 @@ def list_tables(
arg: str | None = None,
_arg_type: ArgType = ArgType.PARSED_QUERY,
verbose: bool = False,
) -> list[tuple]:
) -> list[SQLResult]:
if arg:
query = f'SHOW FIELDS FROM {arg}'
else:
query = "SHOW TABLES"
logger.debug(query)
cur.execute(query)
tables = cur.fetchall()
status = ""
if cur.description:
headers = [x[0] for x in cur.description]
else:
return [(None, None, None, "")]
return [SQLResult(status="")]

if verbose and arg:
query = f'SHOW CREATE TABLE {arg}'
Expand All @@ -40,25 +40,25 @@ def list_tables(
if one := cur.fetchone():
status = one[1]

return [(None, tables, headers, status)]
return [SQLResult(results=cur, headers=headers, status=status)]


@special_command("\\l", "\\l", "List databases.", arg_type=ArgType.RAW_QUERY, case_sensitive=True)
def list_databases(cur: Cursor, **_) -> list[tuple]:
def list_databases(cur: Cursor, **_) -> list[SQLResult]:
query = "SHOW DATABASES"
logger.debug(query)
cur.execute(query)
if cur.description:
headers = [x[0] for x in cur.description]
return [(None, cur, headers, "")]
return [SQLResult(results=cur, headers=headers, status="")]
else:
return [(None, None, None, "")]
return [SQLResult(status="")]


@special_command(
"status", "\\s", "Get status information from the server.", arg_type=ArgType.RAW_QUERY, aliases=["\\s"], case_sensitive=True
)
def status(cur: Cursor, **_) -> list[tuple]:
def status(cur: Cursor, **_) -> list[SQLResult]:
query = "SHOW GLOBAL STATUS;"
logger.debug(query)
try:
Expand Down Expand Up @@ -167,4 +167,4 @@ def status(cur: Cursor, **_) -> list[tuple]:
footer.append("\n" + stats_str)

footer.append("--------------")
return [("\n".join(title), output, "", "\n".join(footer))]
return [SQLResult(title="\n".join(title), results=output, headers="", status="\n".join(footer))]
10 changes: 6 additions & 4 deletions mycli/packages/special/delimitercommand.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import sqlparse

from mycli.packages.sqlresult import SQLResult

sqlparse.engine.grouping.MAX_GROUPING_DEPTH = None # type: ignore[assignment]
sqlparse.engine.grouping.MAX_GROUPING_TOKENS = None # type: ignore[assignment]

Expand Down Expand Up @@ -58,7 +60,7 @@ def queries_iter(self, input_str: str) -> Generator[str, None, None]:
combined_statement += delimiter
queries = self._split(combined_statement)[1:]

def set(self, arg: str, **_) -> list[tuple[None, None, None, str]]:
def set(self, arg: str, **_) -> list[SQLResult]:
"""Change delimiter.

Since `arg` is everything that follows the DELIMITER token
Expand All @@ -70,14 +72,14 @@ def set(self, arg: str, **_) -> list[tuple[None, None, None, str]]:
match = arg and re.search(r"[^\s]+", arg)
if not match:
message = "Missing required argument, delimiter"
return [(None, None, None, message)]
return [SQLResult(status=message)]

delimiter = match.group()
if delimiter.lower() == "delimiter":
return [(None, None, None, 'Invalid delimiter "delimiter"')]
return [SQLResult(status='Invalid delimiter "delimiter"')]

self._delimiter = delimiter
return [(None, None, None, f'Changed delimiter to {delimiter}')]
return [SQLResult(status=f'Changed delimiter to {delimiter}')]

@property
def current(self) -> str:
Expand Down
Loading