From 2c0f2aac5cfe2dad7c54fb318c1ca70d85aaa352 Mon Sep 17 00:00:00 2001 From: Roland Walker Date: Sat, 28 Feb 2026 08:05:08 -0500 Subject: [PATCH] allow styling of status, timing, and warnings text per settings in ~/.myclirc. The new styles are off by default, just available to change. Warnings styles represent a set of independent header, rows, etc. for the whole warnings table. It is not yet possible to style the borders of the warnings table. For consistency, timings were added to warnings, which previously were not shown. This requires updating some tests to use a different method for capturing the standard output. Otherwise we get an error from deep within prompt_toolkit. --- changelog.md | 1 + mycli/clistyle.py | 25 ++++++++++++++++-- mycli/main.py | 60 +++++++++++++++++++++++++++++++++++-------- mycli/myclirc | 8 ++++++ test/myclirc | 8 ++++++ test/test_clistyle.py | 10 ++++---- test/test_main.py | 22 ++++++++++------ 7 files changed, 108 insertions(+), 26 deletions(-) diff --git a/changelog.md b/changelog.md index e7708833..b5901554 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,7 @@ Upcoming (TBD) Features --------- * Offer filename completions on more special commands, such as `\edit`. +* Allow styling of status, timing, and warnings text. Bug Fixes diff --git a/mycli/clistyle.py b/mycli/clistyle.py index 9e860924..9f6d21c4 100644 --- a/mycli/clistyle.py +++ b/mycli/clistyle.py @@ -36,6 +36,14 @@ Token.Output.OddRow: "output.odd-row", Token.Output.EvenRow: "output.even-row", Token.Output.Null: "output.null", + Token.Output.Status: "output.status", + Token.Output.Timing: "output.timing", + Token.Warnings.Header: "warnings.header", + Token.Warnings.OddRow: "warnings.odd-row", + Token.Warnings.EvenRow: "warnings.even-row", + Token.Warnings.Null: "warnings.null", + Token.Warnings.Status: "warnings.status", + Token.Warnings.Timing: "warnings.timing", Token.Prompt: "prompt", Token.Continuation: "continuation", } @@ -96,7 +104,7 @@ def parse_pygments_style( return token_type, style_dict[token_name] -def style_factory(name: str, cli_style: dict[str, str]) -> _MergedStyle: +def style_factory_toolkit(name: str, cli_style: dict[str, str]) -> _MergedStyle: try: style: PygmentsStyle = pygments.styles.get_style_by_name(name) except ClassNotFound: @@ -124,7 +132,11 @@ def style_factory(name: str, cli_style: dict[str, str]) -> _MergedStyle: return merge_styles([style_from_pygments_cls(style), override_style, Style(prompt_styles)]) -def style_factory_output(name: str, cli_style: dict[str, str]) -> PygmentsStyle: +def style_factory_helpers( + name: str, + cli_style: dict[str, str], + warnings: bool = False, +) -> PygmentsStyle: try: style: dict[PygmentsStyle | str, str] = pygments.styles.get_style_by_name(name).styles except ClassNotFound: @@ -144,6 +156,15 @@ def style_factory_output(name: str, cli_style: dict[str, str]) -> PygmentsStyle: # TODO: cli helpers will have to switch to ptk.Style logger.error("Unhandled style / class name: %s", token) + if warnings: + for warning_token in style: + if 'Warnings' not in str(warning_token): + continue + warning_str = str(warning_token) + output_str = warning_str.replace('Warnings', 'Output') + output_token = string_to_tokentype(output_str) + style[output_token] = style[warning_token] + class OutputStyle(PygmentsStyle): default_style = "" styles = style diff --git a/mycli/main.py b/mycli/main.py index e6a22145..3d0a5b0f 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -32,13 +32,21 @@ import click from configobj import ConfigObj import keyring +from prompt_toolkit import print_formatted_text from prompt_toolkit.application.current import get_app from prompt_toolkit.auto_suggest import AutoSuggestFromHistory from prompt_toolkit.completion import Completion, DynamicCompleter from prompt_toolkit.document import Document from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode from prompt_toolkit.filters import Condition, HasFocus, IsDone -from prompt_toolkit.formatted_text import ANSI, AnyFormattedText +from prompt_toolkit.formatted_text import ( + ANSI, + HTML, + AnyFormattedText, + FormattedText, + to_formatted_text, + to_plain_text, +) from prompt_toolkit.key_binding.bindings.named_commands import register as prompt_register from prompt_toolkit.key_binding.key_processor import KeyPressEvent from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor @@ -54,7 +62,7 @@ from mycli import __version__ from mycli.clibuffer import cli_is_multiline -from mycli.clistyle import style_factory, style_factory_output +from mycli.clistyle import style_factory_helpers, style_factory_toolkit from mycli.clitoolbar import create_toolbar_tokens_func from mycli.compat import WIN from mycli.completion_refresher import CompletionRefresher @@ -206,7 +214,9 @@ def __init__( self.syntax_style = c["main"]["syntax_style"] self.less_chatty = c["main"].as_bool("less_chatty") self.cli_style = c["colors"] - self.output_style = style_factory_output(self.syntax_style, self.cli_style) + self.toolkit_style = style_factory_toolkit(self.syntax_style, self.cli_style) + self.helpers_style = style_factory_helpers(self.syntax_style, self.cli_style) + self.helpers_warnings_style = style_factory_helpers(self.syntax_style, self.cli_style, warnings=True) self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") c_dest_warning = c["main"].as_bool("destructive_warning") self.destructive_warning = c_dest_warning if warn is None else warn @@ -880,6 +890,13 @@ def handle_unprettify_binding(self, text: str) -> str: unpretty_text = unpretty_text + ';' return unpretty_text + def output_timing(self, timing: str, is_warnings_style: bool = False) -> None: + self.log_output(timing) + add_style = 'class:warnings.timing' if is_warnings_style else 'class:output.timing' + formatted_timing = FormattedText([('', timing)]) + styled_timing = to_formatted_text(formatted_timing, style=add_style) + print_formatted_text(styled_timing, style=self.toolkit_style) + def run_cli(self) -> None: iterations = 0 sqlexecute = self.sqlexecute @@ -1001,7 +1018,7 @@ def output_res(results: Generator[SQLResult], start: float) -> None: assert self.prompt_app is not None self.prompt_app.output.bell() if special.is_timing_enabled(): - self.echo(f"Time: {t:0.03f}s") + self.output_timing(f"Time: {t:0.03f}s") except KeyboardInterrupt: pass @@ -1012,7 +1029,10 @@ def output_res(results: Generator[SQLResult], start: float) -> None: # get and display warnings if enabled if self.show_warnings and isinstance(result.rows, Cursor) and result.rows.warning_count > 0: warnings = sqlexecute.run("SHOW WARNINGS") + t = time() - start + saw_warning = False for warning in warnings: + saw_warning = True formatted = self.format_sqlresult( warning, is_expanded=special.is_expanded_output(), @@ -1021,9 +1041,13 @@ def output_res(results: Generator[SQLResult], start: float) -> None: numeric_alignment=self.numeric_alignment, binary_display=self.binary_display, max_width=max_width, + is_warnings_style=True, ) self.echo("") - self.output(formatted, warning.status) + self.output(formatted, warning.status, is_warnings_style=True) + + if saw_warning and special.is_timing_enabled(): + self.output_timing(f"Time: {t:0.03f}s", is_warnings_style=True) def keepalive_hook(_context): """ @@ -1105,7 +1129,7 @@ def one_iteration(text: str | None = None) -> None: click.echo(context) click.echo("---") if special.is_timing_enabled(): - click.echo(f"Time: {duration:.2f} seconds") + self.output_timing(f"Time: {duration:.2f} seconds") text = self.prompt_app.prompt( default=sql or '', inputhook=inputhook, @@ -1264,7 +1288,8 @@ def one_iteration(text: str | None = None) -> None: auto_suggest=AutoSuggestFromHistory(), complete_while_typing=complete_while_typing_filter, multiline=cli_is_multiline(self), - style=style_factory(self.syntax_style, self.cli_style), + # why not self.toolkit_style here? + style=style_factory_toolkit(self.syntax_style, self.cli_style), include_default_pygments_style=False, key_bindings=key_bindings, enable_open_in_editor=True, @@ -1344,8 +1369,10 @@ def log_query(self, query: str) -> None: self.logfile.write(query) self.logfile.write("\n") - def log_output(self, output: str) -> None: + def log_output(self, output: str | AnyFormattedText) -> None: """Log the output in the audit log, if it's enabled.""" + if isinstance(output, (ANSI, HTML, FormattedText)): + output = to_plain_text(output) if isinstance(self.logfile, TextIOWrapper): click.echo(output, file=self.logfile) @@ -1371,7 +1398,12 @@ def get_output_margin(self, status: str | None = None) -> int: return margin - def output(self, output: itertools.chain[str], status: str | None = None) -> None: + def output( + self, + output: itertools.chain[str], + status: str | None = None, + is_warnings_style: bool = False, + ) -> None: """Output text to stdout or a pager command. The status text is not outputted to pager or files. @@ -1433,8 +1465,12 @@ def newlinewrapper(text: list[str]) -> Generator[str, None, None]: click.secho(line) if status: + # todo allow status to be a FormattedText, but strip before logging self.log_output(status) - click.secho(status) + add_style = 'class:warnings.status' if is_warnings_style else 'class:output.status' + formatted_status = FormattedText([('', status)]) + styled_status = to_formatted_text(formatted_status, style=add_style) + print_formatted_text(styled_status, style=self.toolkit_style) def configure_pager(self) -> None: # Provide sane defaults for less if they are empty. @@ -1576,6 +1612,7 @@ def run_query( null_string=self.null_string, numeric_alignment=self.numeric_alignment, binary_display=self.binary_display, + is_warnings_style=True, ) for line in output: click.echo(line, nl=new_line) @@ -1592,6 +1629,7 @@ def format_sqlresult( numeric_alignment: str = 'right', binary_display: str | None = None, max_width: int | None = None, + is_warnings_style: bool = False, ) -> itertools.chain[str]: if is_redirected: use_formatter = self.redirect_formatter @@ -1605,7 +1643,7 @@ def format_sqlresult( "dialect": "unix", "disable_numparse": True, "preserve_whitespace": True, - "style": self.output_style, + "style": self.helpers_warnings_style if is_warnings_style else self.helpers_style, } default_kwargs = use_formatter._output_formats[use_formatter.format_name].formatter_args diff --git a/mycli/myclirc b/mycli/myclirc index 6f65090a..dbcfc506 100644 --- a/mycli/myclirc +++ b/mycli/myclirc @@ -250,6 +250,14 @@ output.header = "#00ff5f bold" output.odd-row = "" output.even-row = "" output.null = "#808080" +output.status = "" +output.timing = "" +warnings.header = "#00ff5f bold" +warnings.odd-row = "" +warnings.even-row = "" +warnings.null = "#808080" +warnings.status = "" +warnings.timing = "" # SQL syntax highlighting overrides # sql.comment = 'italic #408080' diff --git a/test/myclirc b/test/myclirc index e69cdd8b..56b92dcb 100644 --- a/test/myclirc +++ b/test/myclirc @@ -248,6 +248,14 @@ output.header = "#00ff5f bold" output.odd-row = "" output.even-row = "" output.null = "#808080" +output.status = "" +output.timing = "" +warnings.header = "#00ff5f bold" +warnings.odd-row = "" +warnings.even-row = "" +warnings.null = "#808080" +warnings.status = "" +warnings.timing = "" # SQL syntax highlighting overrides # sql.comment = 'italic #408080' diff --git a/test/test_clistyle.py b/test/test_clistyle.py index cb6bdcb2..f6ac429d 100644 --- a/test/test_clistyle.py +++ b/test/test_clistyle.py @@ -6,15 +6,15 @@ from pygments.token import Token import pytest -from mycli.clistyle import style_factory +from mycli.clistyle import style_factory_toolkit @pytest.mark.skip(reason="incompatible with new prompt toolkit") -def test_style_factory(): +def test_style_factory_toolkit(): """Test that a Pygments Style class is created.""" header = "bold underline #ansired" cli_style = {"Token.Output.Header": header} - style = style_factory("default", cli_style) + style = style_factory_toolkit("default", cli_style) assert isinstance(style(), Style) assert Token.Output.Header in style.styles @@ -22,8 +22,8 @@ def test_style_factory(): @pytest.mark.skip(reason="incompatible with new prompt toolkit") -def test_style_factory_unknown_name(): +def test_style_factory_toolkit_unknown_name(): """Test that an unrecognized name will not throw an error.""" - style = style_factory("foobar", {}) + style = style_factory_toolkit("foobar", {}) assert isinstance(style(), Style) diff --git a/test/test_main.py b/test/test_main.py index 88e92b11..be3a32a5 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -1,7 +1,9 @@ # type: ignore from collections import namedtuple +from contextlib import redirect_stdout import csv +import io import os import shutil from tempfile import NamedTemporaryFile @@ -42,7 +44,7 @@ @dbtest -def test_binary_display_hex(executor, capsys): +def test_binary_display_hex(executor): m = MyCli() m.sqlexecute = SQLExecute( None, @@ -72,14 +74,16 @@ def test_binary_display_hex(executor, capsys): binary_display="hex", max_width=None, ) - m.output(formatted, sqlresult.status) + f = io.StringIO() + with redirect_stdout(f): + m.output(formatted, sqlresult.status) expected = " 0x6a " - stdout = capsys.readouterr().out - assert expected in stdout + output = f.getvalue() + assert expected in output @dbtest -def test_binary_display_utf8(executor, capsys): +def test_binary_display_utf8(executor): m = MyCli() m.sqlexecute = SQLExecute( None, @@ -109,10 +113,12 @@ def test_binary_display_utf8(executor, capsys): binary_display="utf8", max_width=None, ) - m.output(formatted, sqlresult.status) + f = io.StringIO() + with redirect_stdout(f): + m.output(formatted, sqlresult.status) expected = " j " - stdout = capsys.readouterr().out - assert expected in stdout + output = f.getvalue() + assert expected in output @dbtest