diff --git a/mycli/main.py b/mycli/main.py index 5d8faa08..f00af3af 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -56,7 +56,7 @@ from mycli.packages import special from mycli.packages.filepaths import dir_path_exists, guess_socket_location from mycli.packages.hybrid_redirection import get_redirect_components, is_redirect_command -from mycli.packages.parseutils import is_destructive, is_dropping_database +from mycli.packages.parseutils import is_destructive, is_dropping_database, is_valid_connection_scheme from mycli.packages.prompt_utils import confirm, confirm_destructive_query from mycli.packages.special.favoritequeries import FavoriteQueries from mycli.packages.special.main import ArgType @@ -1584,7 +1584,14 @@ def cli( password = click.prompt("Enter password", hide_input=True, show_default=False, default='', type=str, err=True) # if the password value looks like a DSN, treat it as such and # prompt for password - elif database is None and password is not None and password.startswith("mysql://"): + elif database is None and password is not None and "://" in password: + # check if the scheme is valid. We do not actually have any logic for these, but + # it will most usefully catch the case where we erroneously catch someone's + # password, and give them an easy error message to follow / report + is_valid_scheme, scheme = is_valid_connection_scheme(password) + if not is_valid_scheme: + click.secho(f"Error: Unknown connection scheme provided for DSN URI ({scheme}://)", err=True, fg="red") + sys.exit(1) database = password password = click.prompt("Enter password", hide_input=True, show_default=False, default='', type=str, err=True) # getting the envvar ourselves because the envvar from a click diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py index b29e7cbd..c47f9472 100644 --- a/mycli/packages/parseutils.py +++ b/mycli/packages/parseutils.py @@ -23,6 +23,17 @@ } +def is_valid_connection_scheme(text: str) -> tuple[bool, str | None]: + # exit early if the text does not resemble a DSN URI + if "://" not in text: + return False, None + scheme = text.split("://")[0] + if scheme not in ("mysql", "mysqlx", "tcp", "socket", "ssh"): + return False, scheme + else: + return True, None + + def last_word(text: str, include: str = "alphanum_underscore") -> str: r""" Find the last word in a sentence. diff --git a/test/test_main.py b/test/test_main.py index fec23cb9..ebbed6c7 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -11,7 +11,7 @@ from click.testing import CliRunner from pymysql.err import OperationalError -from mycli.main import MyCli, cli, thanks_picker +from mycli.main import MyCli, cli, is_valid_connection_scheme, thanks_picker import mycli.packages.special from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS from mycli.sqlexecute import ServerInfo, SQLExecute @@ -40,6 +40,16 @@ ] +def test_is_valid_connection_scheme_valid(executor, capsys): + is_valid, scheme = is_valid_connection_scheme("mysql://test@localhost:3306/dev") + assert is_valid + + +def test_is_valid_connection_scheme_invalid(executor, capsys): + is_valid, scheme = is_valid_connection_scheme("nope://test@localhost:3306/dev") + assert not is_valid + + @dbtest def test_ssl_mode_on(executor, capsys): runner = CliRunner()