diff --git a/.gitignore b/.gitignore index 970fcd4f..1fb195db 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,6 @@ .venv/ venv/ + +.myclirc +uv.lock diff --git a/changelog.md b/changelog.md index 29776b9a..f8b5a279 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,8 @@ Upcoming (TBD) Features -------- * Update query processing functions to allow automatic show_warnings to work for more code paths like DDL. +* Add new ssl_mode config / --ssl-mode CLI option to control SSL connection behavior. This setting will supercede the + existing --ssl/--no-ssl CLI options, which are deprecated and will be removed in a future release. * Rework reconnect logic to actually reconnect or create a new connection instead of simply changing the database (#746). diff --git a/mycli/main.py b/mycli/main.py index d062e05a..8a35de3c 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -38,6 +38,7 @@ from prompt_toolkit.lexers import PygmentsLexer from prompt_toolkit.shortcuts import CompleteStyle, PromptSession import pymysql +from pymysql.constants.ER import HANDSHAKE_ERROR from pymysql.cursors import Cursor import sqlglot import sqlparse @@ -154,6 +155,14 @@ def __init__( self.login_path_as_host = c["main"].as_bool("login_path_as_host") self.post_redirect_command = c['main'].get('post_redirect_command') + # set ssl_mode if a valid option is provided in a config file, otherwise None + ssl_mode = c["main"].get("ssl_mode", None) + if ssl_mode not in ("auto", "on", "off", None): + self.echo(f"Invalid config option provided for ssl_mode ({ssl_mode}); ignoring.", err=True, fg="red") + self.ssl_mode = None + else: + self.ssl_mode = ssl_mode + # read from cli argument or user config file self.auto_vertical_output = auto_vertical_output or c["main"].as_bool("auto_vertical_output") self.show_warnings = show_warnings or c["main"].as_bool("show_warnings") @@ -566,6 +575,24 @@ def _connect() -> None: ssh_key_filename, init_command, ) + elif e.args[0] == HANDSHAKE_ERROR and ssl is not None and ssl.get("mode", None) == "auto": + self.sqlexecute = SQLExecute( + database, + user, + passwd, + host, + int_port, + socket, + charset, + use_local_infile, + None, + ssh_user, + ssh_host, + int(ssh_port) if ssh_port else None, + ssh_password, + ssh_key_filename, + init_command, + ) else: raise e @@ -1398,7 +1425,13 @@ def get_last_query(self) -> str | None: @click.option("--ssh-key-filename", help="Private key filename (identify file) for the ssh connection.") @click.option("--ssh-config-path", help="Path to ssh configuration.", default=os.path.expanduser("~") + "/.ssh/config") @click.option("--ssh-config-host", help="Host to connect to ssh server reading from ssh configuration.") -@click.option("--ssl", "ssl_enable", is_flag=True, help="Enable SSL for connection (automatically enabled with other flags).") +@click.option( + "--ssl-mode", + "ssl_mode", + help="Set desired SSL behavior. auto=preferred, on=required, off=off.", + type=click.Choice(["auto", "on", "off"]), +) +@click.option("--ssl/--no-ssl", "ssl_enable", default=None, help="Enable SSL for connection (automatically enabled with other flags).") @click.option("--ssl-ca", help="CA file in PEM format.", type=click.Path(exists=True)) @click.option("--ssl-capath", help="CA directory.") @click.option("--ssl-cert", help="X509 cert in PEM format.", type=click.Path(exists=True)) @@ -1414,8 +1447,6 @@ def get_last_query(self) -> str | None: is_flag=True, help=("""Verify server's "Common Name" in its cert against hostname used when connecting. This option is disabled by default."""), ) -# as of 2016-02-15 revocation list is not supported by underling PyMySQL -# library (--ssl-crl and --ssl-crlpath options in vanilla mysql client) @click.version_option(__version__, "-V", "--version", help="Output mycli's version.") @click.option("-v", "--verbose", is_flag=True, help="Verbose output.") @click.option("-D", "--database", "dbname", help="Database to use.") @@ -1464,6 +1495,7 @@ def cli( auto_vertical_output: bool, show_warnings: bool, local_infile: bool, + ssl_mode: str | None, ssl_enable: bool, ssl_ca: str | None, ssl_capath: str | None, @@ -1510,6 +1542,15 @@ def cli( warn=warn, myclirc=myclirc, ) + + if ssl_enable is not None: + click.secho( + "Warning: The --ssl/--no-ssl CLI options are deprecated and will be removed in a future release. " + "Please use the ssl_mode config or --ssl-mode CLI options instead.", + err=True, + fg="yellow", + ) + if list_dsn: try: alias_dsn = mycli.config["alias_dsn"] @@ -1606,19 +1647,36 @@ def cli( ssl_verify_server_cert = ssl_verify_server_cert or (params[0].lower() == 'true') ssl_enable = True - ssl = { - "enable": ssl_enable, - "ca": ssl_ca and os.path.expanduser(ssl_ca), - "cert": ssl_cert and os.path.expanduser(ssl_cert), - "key": ssl_key and os.path.expanduser(ssl_key), - "capath": ssl_capath, - "cipher": ssl_cipher, - "tls_version": tls_version, - "check_hostname": ssl_verify_server_cert, - } - - # remove empty ssl options - ssl = {k: v for k, v in ssl.items() if v is not None} + ssl_mode = ssl_mode or mycli.ssl_mode # cli option or config option + + # if there is a mismatch between the ssl_mode value and other sources of ssl config, show a warning + # specifically using "is False" to not pickup the case where ssl_enable is None (not set by the user) + if ssl_enable and ssl_mode == "off" or ssl_enable is False and ssl_mode in ("auto", "on"): + click.secho( + f"Warning: The current ssl_mode value of '{ssl_mode}' is overriding the value provided by " + f"either the --ssl/--no-ssl CLI options or a DSN URI parameter (ssl={ssl_enable}).", + err=True, + fg="yellow", + ) + + # configure SSL if ssl_mode is auto/on or if + # ssl_enable = True (from --ssl or a DSN URI) and ssl_mode is None + if ssl_mode in ("auto", "on") or (ssl_enable and ssl_mode is None): + ssl = { + "mode": ssl_mode, + "enable": ssl_enable, + "ca": ssl_ca and os.path.expanduser(ssl_ca), + "cert": ssl_cert and os.path.expanduser(ssl_cert), + "key": ssl_key and os.path.expanduser(ssl_key), + "capath": ssl_capath, + "cipher": ssl_cipher, + "tls_version": tls_version, + "check_hostname": ssl_verify_server_cert, + } + # remove empty ssl options + ssl = {k: v for k, v in ssl.items() if v is not None} + else: + ssl = None if ssh_config_host: ssh_config = read_ssh_config(ssh_config_path).lookup(ssh_config_host) diff --git a/mycli/myclirc b/mycli/myclirc index a9e15808..872a904f 100644 --- a/mycli/myclirc +++ b/mycli/myclirc @@ -5,6 +5,13 @@ # after executing a SQL statement when applicable. show_warnings = False +# Sets the desired behavior for handling secure connections to the database server. +# Possible values: +# auto = SSL is preferred. Will attempt to connect via SSL, but will fallback to cleartext as needed. +# on = SSL is required. Will attempt to connect via SSL and will fail if a secure connection is not established. +# off = do not use SSL. Will fail if the server requires a secure connection. +ssl_mode = auto + # Enables context sensitive auto-completion. If this is disabled the all # possible completions will be listed. smart_completion = True diff --git a/test/features/db_utils.py b/test/features/db_utils.py index 5c81b661..0d50ab63 100644 --- a/test/features/db_utils.py +++ b/test/features/db_utils.py @@ -40,7 +40,13 @@ def create_cn(hostname, port, password, username, dbname): """ cn = pymysql.connect( - host=hostname, port=port, user=username, password=password, db=dbname, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor + host=hostname, + port=port, + user=username, + password=password, + db=dbname, + charset="utf8mb4", + cursorclass=pymysql.cursors.DictCursor, ) return cn @@ -57,7 +63,13 @@ def drop_db(hostname="localhost", port=3306, username=None, password=None, dbnam """ cn = pymysql.connect( - host=hostname, port=port, user=username, password=password, db=dbname, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor + host=hostname, + port=port, + user=username, + password=password, + db=dbname, + charset="utf8mb4", + cursorclass=pymysql.cursors.DictCursor, ) with cn.cursor() as cr: diff --git a/test/myclirc b/test/myclirc index a19a34ba..8c9e807e 100644 --- a/test/myclirc +++ b/test/myclirc @@ -5,6 +5,13 @@ # after executing a SQL statement when applicable. show_warnings = False +# Sets the desired behavior for handling secure connections to the database server. +# Possible values: +# auto = SSL is preferred. Will attempt to connect via SSL, but will fallback to cleartext as needed. +# on = SSL is required. Will attempt to connect via SSL and will fail if a secure connection is not established. +# off = do not use SSL. Will fail if the server requires a secure connection. +ssl_mode = auto + # Enables context sensitive auto-completion. If this is disabled the all # possible completions will be listed. smart_completion = True diff --git a/test/test_main.py b/test/test_main.py index 04ac5c18..909508bb 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -1,6 +1,7 @@ # type: ignore from collections import namedtuple +import csv import os import shutil from tempfile import NamedTemporaryFile @@ -38,6 +39,61 @@ ] +@dbtest +def test_ssl_mode_on(executor, capsys): + runner = CliRunner() + ssl_mode = "on" + sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" + result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode], input=sql) + result_dict = next(csv.DictReader(result.stdout.split("\n"))) + ssl_cipher = result_dict["VARIABLE_VALUE"] + assert ssl_cipher + + +@dbtest +def test_ssl_mode_auto(executor, capsys): + runner = CliRunner() + ssl_mode = "auto" + sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" + result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode], input=sql) + result_dict = next(csv.DictReader(result.stdout.split("\n"))) + ssl_cipher = result_dict["VARIABLE_VALUE"] + assert ssl_cipher + + +@dbtest +def test_ssl_mode_off(executor, capsys): + runner = CliRunner() + ssl_mode = "off" + sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" + result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode], input=sql) + result_dict = next(csv.DictReader(result.stdout.split("\n"))) + ssl_cipher = result_dict["VARIABLE_VALUE"] + assert not ssl_cipher + + +@dbtest +def test_ssl_mode_overrides_ssl(executor, capsys): + runner = CliRunner() + ssl_mode = "off" + sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" + result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode, "--ssl"], input=sql) + result_dict = next(csv.DictReader(result.stdout.split("\n"))) + ssl_cipher = result_dict["VARIABLE_VALUE"] + assert not ssl_cipher + + +@dbtest +def test_ssl_mode_overrides_no_ssl(executor, capsys): + runner = CliRunner() + ssl_mode = "on" + sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" + result = runner.invoke(cli, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode, "--no-ssl"], input=sql) + result_dict = next(csv.DictReader(result.stdout.split("\n"))) + ssl_cipher = result_dict["VARIABLE_VALUE"] + assert ssl_cipher + + @dbtest def test_reconnect_no_database(executor, capsys): m = MyCli() @@ -509,6 +565,7 @@ def __init__(self, **args): self.destructive_warning = False self.main_formatter = Formatter() self.redirect_formatter = Formatter() + self.ssl_mode = "auto" def connect(self, **args): MockMyCli.connect_args = args @@ -673,6 +730,7 @@ def __init__(self, **args): self.destructive_warning = False self.main_formatter = Formatter() self.redirect_formatter = Formatter() + self.ssl_mode = "auto" def connect(self, **args): MockMyCli.connect_args = args