Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
Upcoming Release (TBD)
======================

Features
--------

* DSN specific init-command in myclirc. Fixes (#1195)



1.29.2 (2024/12/11)
===================

Expand Down
40 changes: 36 additions & 4 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import stat
from collections import namedtuple

from pygments.lexer import combined

try:
from pwd import getpwuid
except ImportError:
Expand Down Expand Up @@ -1262,9 +1264,13 @@ def cli(

dsn_uri = None

# Treat the database argument as a DSN alias if we're missing
# other connection information.
if mycli.config["alias_dsn"] and database and "://" not in database and not any([user, password, host, port, login_path]):
# Treat the database argument as a DSN alias only if it matches a configured alias
if (
database
and "://" not in database
and not any([user, password, host, port, login_path])
and database in mycli.config.get("alias_dsn", {})
):
dsn, database = database, ""

if database and "://" in database:
Expand Down Expand Up @@ -1306,6 +1312,29 @@ def cli(
ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get("identityfile", [None])[0]

ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename)
# Merge init-commands: global, DSN-specific, then CLI
init_cmds = []
# 1) Global init-commands
global_section = mycli.config.get("init-commands", {})
for _, val in global_section.items():
if isinstance(val, (list, tuple)):
init_cmds.extend(val)
elif val:
init_cmds.append(val)
# 2) DSN-specific init-commands
if dsn:
alias_section = mycli.config.get("alias_dsn.init-commands", {})
if dsn in alias_section:
val = alias_section.get(dsn)
if isinstance(val, (list, tuple)):
init_cmds.extend(val)
elif val:
init_cmds.append(val)
# 3) CLI-provided init_command
if init_command:
init_cmds.append(init_command)

combined_init_cmd = "; ".join(cmd.strip() for cmd in init_cmds if cmd)

mycli.connect(
database=database,
Expand All @@ -1321,11 +1350,14 @@ def cli(
ssh_port=ssh_port,
ssh_password=ssh_password,
ssh_key_filename=ssh_key_filename,
init_command=init_command,
init_command=combined_init_cmd,
charset=charset,
password_file=password_file,
)

if combined_init_cmd:
click.echo("Executing init-command: %s" % combined_init_cmd, err=True)

mycli.logger.debug("Launch Params: \n" "\tdatabase: %r" "\tuser: %r" "\thost: %r" "\tport: %r", database, user, host, port)

# --execute argument
Expand Down
13 changes: 13 additions & 0 deletions mycli/myclirc
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,22 @@ output.null = "#808080"
# sql.whitespace = ''

# Favorite queries.
# You can add your favorite queries here. They will be available in the
# REPL when you type `\f` or `\f <query_name>`.
[favorite_queries]
# example = "SELECT * FROM example_table WHERE id = 1"

# Initial commands to execute when connecting to any database.
[init-commands]
# read_only = "SET SESSION TRANSACTION READ ONLY"


# Use the -d option to reference a DSN.
# Special characters in passwords and other strings can be escaped with URL encoding.
[alias_dsn]
# example_dsn = mysql://[user[:password]@][host][:port][/dbname]

# Initial commands to execute when connecting to a DSN alias.
[alias_dsn.init-commands]
# Define one or more SQL statements per alias (semicolon-separated).
# example_dsn = "SET sql_select_limit=1000; SET time_zone='+00:00'"
2 changes: 1 addition & 1 deletion mycli/sqlexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def connect(
ssl=ssl_context,
program_name="mycli",
defer_connect=defer_connect,
init_command=init_command,
init_command=init_command or None,
)

if ssh_host:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ authors = [{ name = "Mycli Core Team", email = "mycli-dev@googlegroups.com" }]
urls = { homepage = "http://mycli.net" }

dependencies = [
"click >= 7.0",
"click >= 7.0,<8.1.8",
"cryptography >= 1.0.0",
"Pygments>=1.6",
"prompt_toolkit>=3.0.6,<4.0.0",
Expand Down
14 changes: 14 additions & 0 deletions test/myclirc
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,25 @@ output.null = "#808080"
# sql.whitespace = ''

# Favorite queries.
# You can add your favorite queries here. They will be available in the
# REPL when you type `\f` or `\f <query_name>`.
[favorite_queries]
check = 'select "✔"'
foo_args = 'SELECT $1, "$2", "$3"'
# example = "SELECT * FROM example_table WHERE id = 1"

# Initial commands to execute when connecting to any database.
[init-commands]
# read_only = "SET SESSION TRANSACTION READ ONLY"
global_limit = "set sql_select_limit=9999"


# Use the -d option to reference a DSN.
# Special characters in passwords and other strings can be escaped with URL encoding.
[alias_dsn]
# example_dsn = mysql://[user[:password]@][host][:port][/dbname]

# Initial commands to execute when connecting to a DSN alias.
[alias_dsn.init-commands]
# Define one or more SQL statements per alias (semicolon-separated).
# example_dsn = "SET sql_select_limit=1000; SET time_zone='+00:00'"
11 changes: 11 additions & 0 deletions test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,3 +553,14 @@ def test_init_command_multiple_arg(executor):
assert result.exit_code == 0
assert expected_sql_select_limit in result.output
assert expected_max_join_size in result.output

@dbtest
def test_global_init_commands(executor):
"""Tests that global init-commands from config are executed by default."""
# The global init-commands section in test/myclirc sets sql_select_limit=9999
sql = 'show variables like "sql_select_limit";'
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
expected = "sql_select_limit\t9999\n"
assert result.exit_code == 0
assert expected in result.output