diff --git a/changelog.md b/changelog.md index 511f2438..74a98c68 100644 --- a/changelog.md +++ b/changelog.md @@ -3,6 +3,7 @@ TBD Features -------- +* Make password options also function as flags. Reworked password logic to prompt user as early as possible (#341). * More complete and up-to-date set of MySQL reserved words for completions. * Place exact-leading completions first. * Allow history file location to be configured. diff --git a/mycli/main.py b/mycli/main.py index a785146f..5d8faa08 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -64,7 +64,7 @@ from mycli.packages.tabular_output import sql_format from mycli.packages.toolkit.history import FileHistoryWithTimestamp from mycli.sqlcompleter import SQLCompleter -from mycli.sqlexecute import ERROR_CODE_ACCESS_DENIED, FIELD_TYPES, SQLExecute +from mycli.sqlexecute import FIELD_TYPES, SQLExecute try: import paramiko @@ -460,7 +460,7 @@ def connect( self, database: str | None = "", user: str | None = "", - passwd: str | None = "", + passwd: str | None = None, host: str | None = "", port: str | int | None = "", socket: str | None = "", @@ -528,10 +528,19 @@ def connect( # if the passwd is not specified try to set it using the password_file option password_from_file = self.get_password_from_file(password_file) passwd = passwd if isinstance(passwd, str) else password_from_file - passwd = '' if passwd is None else passwd - # Connect to the database. + # password hierarchy + # 1. -p / --pass/--password CLI options + # 2. envvar (MYSQL_PWD) + # 3. DSN (mysql://user:password) + # 4. cnf (.my.cnf / etc) + # 5. --password-file CLI option + + # if no password was found from all of the above sources, ask for a password + if passwd is None: + passwd = click.prompt("Enter password", hide_input=True, show_default=False, default='', type=str, err=True) + # Connect to the database. def _connect() -> None: try: self.sqlexecute = SQLExecute( @@ -552,31 +561,7 @@ def _connect() -> None: init_command, ) except pymysql.OperationalError as e1: - if e1.args[0] == ERROR_CODE_ACCESS_DENIED: - if password_from_file is not None: - new_passwd = password_from_file - else: - new_passwd = click.prompt( - f"Password for {user}", hide_input=True, show_default=False, default='', type=str, err=True - ) - self.sqlexecute = SQLExecute( - database, - user, - new_passwd, - host, - int_port, - socket, - charset, - use_local_infile, - ssl_config_or_none, - ssh_user, - ssh_host, - int(ssh_port) if ssh_port else None, - ssh_password, - ssh_key_filename, - init_command, - ) - elif e1.args[0] == HANDSHAKE_ERROR and ssl is not None and ssl.get("mode", None) == "auto": + if e1.args[0] == HANDSHAKE_ERROR and ssl is not None and ssl.get("mode", None) == "auto": try: self.sqlexecute = SQLExecute( database, @@ -595,33 +580,8 @@ def _connect() -> None: ssh_key_filename, init_command, ) - except pymysql.OperationalError as e2: - if e2.args[0] == ERROR_CODE_ACCESS_DENIED: - if password_from_file is not None: - new_passwd = password_from_file - else: - new_passwd = click.prompt( - f"Password for {user}", hide_input=True, show_default=False, default='', type=str, err=True - ) - self.sqlexecute = SQLExecute( - database, - user, - new_passwd, - host, - int_port, - socket, - charset, - use_local_infile, - None, - ssh_user, - ssh_host, - int(ssh_port) if ssh_port else None, - ssh_password, - ssh_key_filename, - init_command, - ) - else: - raise e2 + except Exception as e2: + raise e2 else: raise e1 @@ -1492,8 +1452,16 @@ def get_last_query(self) -> str | None: @click.option("-P", "--port", envvar="MYSQL_TCP_PORT", type=int, help="Port number to use for connection. Honors $MYSQL_TCP_PORT.") @click.option("-u", "--user", help="User name to connect to the database.") @click.option("-S", "--socket", envvar="MYSQL_UNIX_PORT", help="The socket file to use for connection.") -@click.option("-p", "--password", "password", envvar="MYSQL_PWD", type=str, help="Password to connect to the database.") -@click.option("--pass", "password", envvar="MYSQL_PWD", type=str, help="Password to connect to the database.") +@click.option( + "-p", + "--pass", + "--password", + "password", + is_flag=False, + flag_value="MYCLI_ASK_PASSWORD", + type=str, + help="Prompt for (or enter in cleartext) password to connect to the database.", +) @click.option("--ssh-user", help="User name to connect to ssh server.") @click.option("--ssh-host", help="Host name to connect to ssh server.") @click.option("--ssh-port", default=22, help="Port to connect to ssh server.") @@ -1553,9 +1521,11 @@ def get_last_query(self) -> str | None: @click.option( "--password-file", type=click.Path(), help="File or FIFO path containing the password to connect to the db if not specified otherwise." ) -@click.argument("database", default="", nargs=1) +@click.argument("database", default=None, nargs=1) +@click.pass_context def cli( - database: str, + ctx: click.Context, + database: str | None, user: str | None, host: str | None, port: int | None, @@ -1608,6 +1578,20 @@ def cli( - mycli mysql://my_user@my_host.com:3306/my_database """ + # if user passes the --p* flag, ask for the password right away + # to reduce lag as much as possible + if password == "MYCLI_ASK_PASSWORD": + password = click.prompt("Enter password", hide_input=True, show_default=False, default='', type=str, err=True) + # if the password value looks like a DSN, treat it as such and + # prompt for password + elif database is None and password is not None and password.startswith("mysql://"): + database = password + password = click.prompt("Enter password", hide_input=True, show_default=False, default='', type=str, err=True) + # getting the envvar ourselves because the envvar from a click + # option cannot be an empty string, but a password can be + elif password is None and os.environ.get("MYSQL_PWD") is not None: + password = os.environ.get("MYSQL_PWD") + mycli = MyCli( prompt=prompt, logfile=logfile, diff --git a/test/test_main.py b/test/test_main.py index 4f22a208..fec23cb9 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -47,7 +47,7 @@ def test_ssl_mode_on(executor, capsys): sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode], input=sql) result_dict = next(csv.DictReader(result.stdout.split("\n"))) - ssl_cipher = result_dict["VARIABLE_VALUE"] + ssl_cipher = result_dict.get("VARIABLE_VALUE", None) assert ssl_cipher @@ -58,7 +58,7 @@ def test_ssl_mode_auto(executor, capsys): sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode], input=sql) result_dict = next(csv.DictReader(result.stdout.split("\n"))) - ssl_cipher = result_dict["VARIABLE_VALUE"] + ssl_cipher = result_dict.get("VARIABLE_VALUE", None) assert ssl_cipher @@ -69,7 +69,7 @@ def test_ssl_mode_off(executor, capsys): sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode], input=sql) result_dict = next(csv.DictReader(result.stdout.split("\n"))) - ssl_cipher = result_dict["VARIABLE_VALUE"] + ssl_cipher = result_dict.get("VARIABLE_VALUE", None) assert not ssl_cipher @@ -80,7 +80,7 @@ def test_ssl_mode_overrides_ssl(executor, capsys): sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode, "--ssl"], input=sql) result_dict = next(csv.DictReader(result.stdout.split("\n"))) - ssl_cipher = result_dict["VARIABLE_VALUE"] + ssl_cipher = result_dict.get("VARIABLE_VALUE", None) assert not ssl_cipher @@ -91,7 +91,7 @@ def test_ssl_mode_overrides_no_ssl(executor, capsys): sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode, "--no-ssl"], input=sql) result_dict = next(csv.DictReader(result.stdout.split("\n"))) - ssl_cipher = result_dict["VARIABLE_VALUE"] + ssl_cipher = result_dict.get("VARIABLE_VALUE", None) assert ssl_cipher