Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@

.venv/
venv/

.myclirc
uv.lock
2 changes: 2 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Upcoming (TBD)
Features
--------
* Update query processing functions to allow automatic show_warnings to work for more code paths like DDL.
* Add new ssl_mode config / --ssl-mode CLI option to control SSL connection behavior. This setting will supercede the
existing --ssl/--no-ssl CLI options, which will be deprecated in a later release.
* Rework reconnect logic to actually reconnect or create a new connection instead of simply changing the database (#746).


Expand Down
90 changes: 74 additions & 16 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from prompt_toolkit.lexers import PygmentsLexer
from prompt_toolkit.shortcuts import CompleteStyle, PromptSession
import pymysql
from pymysql.constants.ER import HANDSHAKE_ERROR
from pymysql.cursors import Cursor
import sqlglot
import sqlparse
Expand Down Expand Up @@ -154,6 +155,14 @@ def __init__(
self.login_path_as_host = c["main"].as_bool("login_path_as_host")
self.post_redirect_command = c['main'].get('post_redirect_command')

# set ssl_mode if a valid option is provided in a config file, otherwise None
ssl_mode = c["ssl"].get("ssl_mode", None)
if ssl_mode not in ("auto", "on", "off", None):
self.echo(f"Invalid config option provided for ssl_mode ({ssl_mode}); ignoring.", err=True, fg="red")
self.ssl_mode = None
else:
self.ssl_mode = ssl_mode

# read from cli argument or user config file
self.auto_vertical_output = auto_vertical_output or c["main"].as_bool("auto_vertical_output")
self.show_warnings = show_warnings or c["main"].as_bool("show_warnings")
Expand Down Expand Up @@ -566,6 +575,24 @@ def _connect() -> None:
ssh_key_filename,
init_command,
)
elif e.args[0] == HANDSHAKE_ERROR and ssl is not None and ssl.get("mode", None) == "auto":
self.sqlexecute = SQLExecute(
database,
user,
passwd,
host,
int_port,
socket,
charset,
use_local_infile,
None,
ssh_user,
ssh_host,
int(ssh_port) if ssh_port else None,
ssh_password,
ssh_key_filename,
init_command,
)
else:
raise e

Expand Down Expand Up @@ -1398,7 +1425,13 @@ def get_last_query(self) -> str | None:
@click.option("--ssh-key-filename", help="Private key filename (identify file) for the ssh connection.")
@click.option("--ssh-config-path", help="Path to ssh configuration.", default=os.path.expanduser("~") + "/.ssh/config")
@click.option("--ssh-config-host", help="Host to connect to ssh server reading from ssh configuration.")
@click.option("--ssl", "ssl_enable", is_flag=True, help="Enable SSL for connection (automatically enabled with other flags).")
@click.option(
"--ssl-mode",
"ssl_mode",
help="Set desired SSL behavior. auto=preferred, on=required, off=off.",
type=click.Choice(["auto", "on", "off"]),
)
@click.option("--ssl/--no-ssl", "ssl_enable", default=None, help="Enable SSL for connection (automatically enabled with other flags).")
@click.option("--ssl-ca", help="CA file in PEM format.", type=click.Path(exists=True))
@click.option("--ssl-capath", help="CA directory.")
@click.option("--ssl-cert", help="X509 cert in PEM format.", type=click.Path(exists=True))
Expand All @@ -1414,8 +1447,6 @@ def get_last_query(self) -> str | None:
is_flag=True,
help=("""Verify server's "Common Name" in its cert against hostname used when connecting. This option is disabled by default."""),
)
# as of 2016-02-15 revocation list is not supported by underling PyMySQL
# library (--ssl-crl and --ssl-crlpath options in vanilla mysql client)
@click.version_option(__version__, "-V", "--version", help="Output mycli's version.")
@click.option("-v", "--verbose", is_flag=True, help="Verbose output.")
@click.option("-D", "--database", "dbname", help="Database to use.")
Expand Down Expand Up @@ -1464,6 +1495,7 @@ def cli(
auto_vertical_output: bool,
show_warnings: bool,
local_infile: bool,
ssl_mode: str | None,
ssl_enable: bool,
ssl_ca: str | None,
ssl_capath: str | None,
Expand Down Expand Up @@ -1510,6 +1542,15 @@ def cli(
warn=warn,
myclirc=myclirc,
)

if ssl_enable is not None:
click.secho(
"Warning: The --ssl/--no-ssl CLI options will be deprecated in a future release. "
"Please use the ssl_mode config or --ssl-mode CLI options instead.",
err=True,
fg="yellow",
)

if list_dsn:
try:
alias_dsn = mycli.config["alias_dsn"]
Expand Down Expand Up @@ -1606,19 +1647,36 @@ def cli(
ssl_verify_server_cert = ssl_verify_server_cert or (params[0].lower() == 'true')
ssl_enable = True

ssl = {
"enable": ssl_enable,
"ca": ssl_ca and os.path.expanduser(ssl_ca),
"cert": ssl_cert and os.path.expanduser(ssl_cert),
"key": ssl_key and os.path.expanduser(ssl_key),
"capath": ssl_capath,
"cipher": ssl_cipher,
"tls_version": tls_version,
"check_hostname": ssl_verify_server_cert,
}

# remove empty ssl options
ssl = {k: v for k, v in ssl.items() if v is not None}
ssl_mode = ssl_mode or mycli.ssl_mode # cli option or config option

# if there is a mismatch between the ssl_mode value and other sources of ssl config, show a warning
# specifically using "is False" to not pickup the case where ssl_enable is None (not set by the user)
if ssl_enable and ssl_mode == "off" or ssl_enable is False and ssl_mode in ("auto", "on"):
click.secho(
f"Warning: The current ssl_mode value of '{ssl_mode}' is overriding the value provided by "
f"either the --ssl/--no-ssl CLI options or a DSN URI parameter (ssl={ssl_enable}).",
err=True,
fg="yellow",
)

# configure SSL if ssl_mode is auto/on or if
# ssl_enable = True (from --ssl or a DSN URI) and ssl_mode is None
if ssl_mode in ("auto", "on") or (ssl_enable and ssl_mode is None):
ssl = {
"mode": ssl_mode,
"enable": ssl_enable,
"ca": ssl_ca and os.path.expanduser(ssl_ca),
"cert": ssl_cert and os.path.expanduser(ssl_cert),
"key": ssl_key and os.path.expanduser(ssl_key),
"capath": ssl_capath,
"cipher": ssl_cipher,
"tls_version": tls_version,
"check_hostname": ssl_verify_server_cert,
}
# remove empty ssl options
ssl = {k: v for k, v in ssl.items() if v is not None}
else:
ssl = None

if ssh_config_host:
ssh_config = read_ssh_config(ssh_config_path).lookup(ssh_config_host)
Expand Down
8 changes: 8 additions & 0 deletions mycli/myclirc
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,11 @@ output.null = "#808080"
[alias_dsn.init-commands]
# Define one or more SQL statements per alias (semicolon-separated).
# example_dsn = "SET sql_select_limit=1000; SET time_zone='+00:00'"

[ssl]
# Sets the desired behavior for handling secure connections to the database server.
# Possible values:
# auto = SSL is preferred. Will attempt to connect via SSL, but will fallback to cleartext as needed.
# on = SSL is required. Will attempt to connect via SSL and will fail if a secure connection is not established.
# off = do not use SSL. Will fail if the server requires a secure connection.
ssl_mode = auto
16 changes: 14 additions & 2 deletions test/features/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@ def create_cn(hostname, port, password, username, dbname):

"""
cn = pymysql.connect(
host=hostname, port=port, user=username, password=password, db=dbname, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor
host=hostname,
port=port,
user=username,
password=password,
db=dbname,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor,
)

return cn
Expand All @@ -57,7 +63,13 @@ def drop_db(hostname="localhost", port=3306, username=None, password=None, dbnam

"""
cn = pymysql.connect(
host=hostname, port=port, user=username, password=password, db=dbname, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor
host=hostname,
port=port,
user=username,
password=password,
db=dbname,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor,
)

with cn.cursor() as cr:
Expand Down
8 changes: 8 additions & 0 deletions test/myclirc
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,11 @@ global_limit = set sql_select_limit=9999
[alias_dsn.init-commands]
# Define one or more SQL statements per alias (semicolon-separated).
# example_dsn = "SET sql_select_limit=1000; SET time_zone='+00:00'"

[ssl]
# Sets the desired behavior for handling secure connections to the database server.
# Possible values:
# auto = SSL is preferred. Will attempt to connect via SSL, but will fallback to cleartext as needed.
# on = SSL is required. Will attempt to connect via SSL and will fail if a secure connection is not established.
# off = do not use SSL. Will fail if the server requires a secure connection.
ssl_mode = auto
58 changes: 58 additions & 0 deletions test/test_main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# type: ignore

from collections import namedtuple
import csv
import os
import shutil
from tempfile import NamedTemporaryFile
Expand Down Expand Up @@ -38,6 +39,61 @@
]


@dbtest
def test_ssl_mode_on(executor, capsys):
runner = CliRunner()
ssl_mode = "on"
sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'"
result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode], input=sql)
result_dict = next(csv.DictReader(result.stdout.split("\n")))
ssl_cipher = result_dict["VARIABLE_VALUE"]
assert ssl_cipher


@dbtest
def test_ssl_mode_auto(executor, capsys):
runner = CliRunner()
ssl_mode = "auto"
sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'"
result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode], input=sql)
result_dict = next(csv.DictReader(result.stdout.split("\n")))
ssl_cipher = result_dict["VARIABLE_VALUE"]
assert ssl_cipher


@dbtest
def test_ssl_mode_off(executor, capsys):
runner = CliRunner()
ssl_mode = "off"
sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'"
result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode], input=sql)
result_dict = next(csv.DictReader(result.stdout.split("\n")))
ssl_cipher = result_dict["VARIABLE_VALUE"]
assert not ssl_cipher


@dbtest
def test_ssl_mode_overrides_ssl(executor, capsys):
runner = CliRunner()
ssl_mode = "off"
sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'"
result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode, "--ssl"], input=sql)
result_dict = next(csv.DictReader(result.stdout.split("\n")))
ssl_cipher = result_dict["VARIABLE_VALUE"]
assert not ssl_cipher


@dbtest
def test_ssl_mode_overrides_no_ssl(executor, capsys):
runner = CliRunner()
ssl_mode = "on"
sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'"
result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode, "--no-ssl"], input=sql)
result_dict = next(csv.DictReader(result.stdout.split("\n")))
ssl_cipher = result_dict["VARIABLE_VALUE"]
assert ssl_cipher


@dbtest
def test_reconnect_no_database(executor, capsys):
m = MyCli()
Expand Down Expand Up @@ -509,6 +565,7 @@ def __init__(self, **args):
self.destructive_warning = False
self.main_formatter = Formatter()
self.redirect_formatter = Formatter()
self.ssl_mode = "auto"

def connect(self, **args):
MockMyCli.connect_args = args
Expand Down Expand Up @@ -673,6 +730,7 @@ def __init__(self, **args):
self.destructive_warning = False
self.main_formatter = Formatter()
self.redirect_formatter = Formatter()
self.ssl_mode = "auto"

def connect(self, **args):
MockMyCli.connect_args = args
Expand Down