diff --git a/changelog.md b/changelog.md index e7708833..b5901554 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,7 @@ Upcoming (TBD) Features --------- * Offer filename completions on more special commands, such as `\edit`. +* Allow styling of status, timing, and warnings text. Bug Fixes diff --git a/mycli/clistyle.py b/mycli/clistyle.py index 9e860924..9f6d21c4 100644 --- a/mycli/clistyle.py +++ b/mycli/clistyle.py @@ -36,6 +36,14 @@ Token.Output.OddRow: "output.odd-row", Token.Output.EvenRow: "output.even-row", Token.Output.Null: "output.null", + Token.Output.Status: "output.status", + Token.Output.Timing: "output.timing", + Token.Warnings.Header: "warnings.header", + Token.Warnings.OddRow: "warnings.odd-row", + Token.Warnings.EvenRow: "warnings.even-row", + Token.Warnings.Null: "warnings.null", + Token.Warnings.Status: "warnings.status", + Token.Warnings.Timing: "warnings.timing", Token.Prompt: "prompt", Token.Continuation: "continuation", } @@ -96,7 +104,7 @@ def parse_pygments_style( return token_type, style_dict[token_name] -def style_factory(name: str, cli_style: dict[str, str]) -> _MergedStyle: +def style_factory_toolkit(name: str, cli_style: dict[str, str]) -> _MergedStyle: try: style: PygmentsStyle = pygments.styles.get_style_by_name(name) except ClassNotFound: @@ -124,7 +132,11 @@ def style_factory(name: str, cli_style: dict[str, str]) -> _MergedStyle: return merge_styles([style_from_pygments_cls(style), override_style, Style(prompt_styles)]) -def style_factory_output(name: str, cli_style: dict[str, str]) -> PygmentsStyle: +def style_factory_helpers( + name: str, + cli_style: dict[str, str], + warnings: bool = False, +) -> PygmentsStyle: try: style: dict[PygmentsStyle | str, str] = pygments.styles.get_style_by_name(name).styles except ClassNotFound: @@ -144,6 +156,15 @@ def style_factory_output(name: str, cli_style: dict[str, str]) -> PygmentsStyle: # TODO: cli helpers will have to switch to ptk.Style logger.error("Unhandled style / class name: %s", token) + if warnings: + for warning_token in style: + if 'Warnings' not in str(warning_token): + continue + warning_str = str(warning_token) + output_str = warning_str.replace('Warnings', 'Output') + output_token = string_to_tokentype(output_str) + style[output_token] = style[warning_token] + class OutputStyle(PygmentsStyle): default_style = "" styles = style diff --git a/mycli/main.py b/mycli/main.py index e6a22145..3d0a5b0f 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -32,13 +32,21 @@ import click from configobj import ConfigObj import keyring +from prompt_toolkit import print_formatted_text from prompt_toolkit.application.current import get_app from prompt_toolkit.auto_suggest import AutoSuggestFromHistory from prompt_toolkit.completion import Completion, DynamicCompleter from prompt_toolkit.document import Document from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode from prompt_toolkit.filters import Condition, HasFocus, IsDone -from prompt_toolkit.formatted_text import ANSI, AnyFormattedText +from prompt_toolkit.formatted_text import ( + ANSI, + HTML, + AnyFormattedText, + FormattedText, + to_formatted_text, + to_plain_text, +) from prompt_toolkit.key_binding.bindings.named_commands import register as prompt_register from prompt_toolkit.key_binding.key_processor import KeyPressEvent from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor @@ -54,7 +62,7 @@ from mycli import __version__ from mycli.clibuffer import cli_is_multiline -from mycli.clistyle import style_factory, style_factory_output +from mycli.clistyle import style_factory_helpers, style_factory_toolkit from mycli.clitoolbar import create_toolbar_tokens_func from mycli.compat import WIN from mycli.completion_refresher import CompletionRefresher @@ -206,7 +214,9 @@ def __init__( self.syntax_style = c["main"]["syntax_style"] self.less_chatty = c["main"].as_bool("less_chatty") self.cli_style = c["colors"] - self.output_style = style_factory_output(self.syntax_style, self.cli_style) + self.toolkit_style = style_factory_toolkit(self.syntax_style, self.cli_style) + self.helpers_style = style_factory_helpers(self.syntax_style, self.cli_style) + self.helpers_warnings_style = style_factory_helpers(self.syntax_style, self.cli_style, warnings=True) self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") c_dest_warning = c["main"].as_bool("destructive_warning") self.destructive_warning = c_dest_warning if warn is None else warn @@ -880,6 +890,13 @@ def handle_unprettify_binding(self, text: str) -> str: unpretty_text = unpretty_text + ';' return unpretty_text + def output_timing(self, timing: str, is_warnings_style: bool = False) -> None: + self.log_output(timing) + add_style = 'class:warnings.timing' if is_warnings_style else 'class:output.timing' + formatted_timing = FormattedText([('', timing)]) + styled_timing = to_formatted_text(formatted_timing, style=add_style) + print_formatted_text(styled_timing, style=self.toolkit_style) + def run_cli(self) -> None: iterations = 0 sqlexecute = self.sqlexecute @@ -1001,7 +1018,7 @@ def output_res(results: Generator[SQLResult], start: float) -> None: assert self.prompt_app is not None self.prompt_app.output.bell() if special.is_timing_enabled(): - self.echo(f"Time: {t:0.03f}s") + self.output_timing(f"Time: {t:0.03f}s") except KeyboardInterrupt: pass @@ -1012,7 +1029,10 @@ def output_res(results: Generator[SQLResult], start: float) -> None: # get and display warnings if enabled if self.show_warnings and isinstance(result.rows, Cursor) and result.rows.warning_count > 0: warnings = sqlexecute.run("SHOW WARNINGS") + t = time() - start + saw_warning = False for warning in warnings: + saw_warning = True formatted = self.format_sqlresult( warning, is_expanded=special.is_expanded_output(), @@ -1021,9 +1041,13 @@ def output_res(results: Generator[SQLResult], start: float) -> None: numeric_alignment=self.numeric_alignment, binary_display=self.binary_display, max_width=max_width, + is_warnings_style=True, ) self.echo("") - self.output(formatted, warning.status) + self.output(formatted, warning.status, is_warnings_style=True) + + if saw_warning and special.is_timing_enabled(): + self.output_timing(f"Time: {t:0.03f}s", is_warnings_style=True) def keepalive_hook(_context): """ @@ -1105,7 +1129,7 @@ def one_iteration(text: str | None = None) -> None: click.echo(context) click.echo("---") if special.is_timing_enabled(): - click.echo(f"Time: {duration:.2f} seconds") + self.output_timing(f"Time: {duration:.2f} seconds") text = self.prompt_app.prompt( default=sql or '', inputhook=inputhook, @@ -1264,7 +1288,8 @@ def one_iteration(text: str | None = None) -> None: auto_suggest=AutoSuggestFromHistory(), complete_while_typing=complete_while_typing_filter, multiline=cli_is_multiline(self), - style=style_factory(self.syntax_style, self.cli_style), + # why not self.toolkit_style here? + style=style_factory_toolkit(self.syntax_style, self.cli_style), include_default_pygments_style=False, key_bindings=key_bindings, enable_open_in_editor=True, @@ -1344,8 +1369,10 @@ def log_query(self, query: str) -> None: self.logfile.write(query) self.logfile.write("\n") - def log_output(self, output: str) -> None: + def log_output(self, output: str | AnyFormattedText) -> None: """Log the output in the audit log, if it's enabled.""" + if isinstance(output, (ANSI, HTML, FormattedText)): + output = to_plain_text(output) if isinstance(self.logfile, TextIOWrapper): click.echo(output, file=self.logfile) @@ -1371,7 +1398,12 @@ def get_output_margin(self, status: str | None = None) -> int: return margin - def output(self, output: itertools.chain[str], status: str | None = None) -> None: + def output( + self, + output: itertools.chain[str], + status: str | None = None, + is_warnings_style: bool = False, + ) -> None: """Output text to stdout or a pager command. The status text is not outputted to pager or files. @@ -1433,8 +1465,12 @@ def newlinewrapper(text: list[str]) -> Generator[str, None, None]: click.secho(line) if status: + # todo allow status to be a FormattedText, but strip before logging self.log_output(status) - click.secho(status) + add_style = 'class:warnings.status' if is_warnings_style else 'class:output.status' + formatted_status = FormattedText([('', status)]) + styled_status = to_formatted_text(formatted_status, style=add_style) + print_formatted_text(styled_status, style=self.toolkit_style) def configure_pager(self) -> None: # Provide sane defaults for less if they are empty. @@ -1576,6 +1612,7 @@ def run_query( null_string=self.null_string, numeric_alignment=self.numeric_alignment, binary_display=self.binary_display, + is_warnings_style=True, ) for line in output: click.echo(line, nl=new_line) @@ -1592,6 +1629,7 @@ def format_sqlresult( numeric_alignment: str = 'right', binary_display: str | None = None, max_width: int | None = None, + is_warnings_style: bool = False, ) -> itertools.chain[str]: if is_redirected: use_formatter = self.redirect_formatter @@ -1605,7 +1643,7 @@ def format_sqlresult( "dialect": "unix", "disable_numparse": True, "preserve_whitespace": True, - "style": self.output_style, + "style": self.helpers_warnings_style if is_warnings_style else self.helpers_style, } default_kwargs = use_formatter._output_formats[use_formatter.format_name].formatter_args diff --git a/mycli/myclirc b/mycli/myclirc index 6f65090a..dbcfc506 100644 --- a/mycli/myclirc +++ b/mycli/myclirc @@ -250,6 +250,14 @@ output.header = "#00ff5f bold" output.odd-row = "" output.even-row = "" output.null = "#808080" +output.status = "" +output.timing = "" +warnings.header = "#00ff5f bold" +warnings.odd-row = "" +warnings.even-row = "" +warnings.null = "#808080" +warnings.status = "" +warnings.timing = "" # SQL syntax highlighting overrides # sql.comment = 'italic #408080' diff --git a/test/myclirc b/test/myclirc index e69cdd8b..56b92dcb 100644 --- a/test/myclirc +++ b/test/myclirc @@ -248,6 +248,14 @@ output.header = "#00ff5f bold" output.odd-row = "" output.even-row = "" output.null = "#808080" +output.status = "" +output.timing = "" +warnings.header = "#00ff5f bold" +warnings.odd-row = "" +warnings.even-row = "" +warnings.null = "#808080" +warnings.status = "" +warnings.timing = "" # SQL syntax highlighting overrides # sql.comment = 'italic #408080' diff --git a/test/test_clistyle.py b/test/test_clistyle.py index cb6bdcb2..f6ac429d 100644 --- a/test/test_clistyle.py +++ b/test/test_clistyle.py @@ -6,15 +6,15 @@ from pygments.token import Token import pytest -from mycli.clistyle import style_factory +from mycli.clistyle import style_factory_toolkit @pytest.mark.skip(reason="incompatible with new prompt toolkit") -def test_style_factory(): +def test_style_factory_toolkit(): """Test that a Pygments Style class is created.""" header = "bold underline #ansired" cli_style = {"Token.Output.Header": header} - style = style_factory("default", cli_style) + style = style_factory_toolkit("default", cli_style) assert isinstance(style(), Style) assert Token.Output.Header in style.styles @@ -22,8 +22,8 @@ def test_style_factory(): @pytest.mark.skip(reason="incompatible with new prompt toolkit") -def test_style_factory_unknown_name(): +def test_style_factory_toolkit_unknown_name(): """Test that an unrecognized name will not throw an error.""" - style = style_factory("foobar", {}) + style = style_factory_toolkit("foobar", {}) assert isinstance(style(), Style) diff --git a/test/test_main.py b/test/test_main.py index 88e92b11..be3a32a5 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -1,7 +1,9 @@ # type: ignore from collections import namedtuple +from contextlib import redirect_stdout import csv +import io import os import shutil from tempfile import NamedTemporaryFile @@ -42,7 +44,7 @@ @dbtest -def test_binary_display_hex(executor, capsys): +def test_binary_display_hex(executor): m = MyCli() m.sqlexecute = SQLExecute( None, @@ -72,14 +74,16 @@ def test_binary_display_hex(executor, capsys): binary_display="hex", max_width=None, ) - m.output(formatted, sqlresult.status) + f = io.StringIO() + with redirect_stdout(f): + m.output(formatted, sqlresult.status) expected = " 0x6a " - stdout = capsys.readouterr().out - assert expected in stdout + output = f.getvalue() + assert expected in output @dbtest -def test_binary_display_utf8(executor, capsys): +def test_binary_display_utf8(executor): m = MyCli() m.sqlexecute = SQLExecute( None, @@ -109,10 +113,12 @@ def test_binary_display_utf8(executor, capsys): binary_display="utf8", max_width=None, ) - m.output(formatted, sqlresult.status) + f = io.StringIO() + with redirect_stdout(f): + m.output(formatted, sqlresult.status) expected = " j " - stdout = capsys.readouterr().out - assert expected in stdout + output = f.getvalue() + assert expected in output @dbtest