Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Upcoming (TBD)
Features
---------
* Offer filename completions on more special commands, such as `\edit`.
* Allow styling of status, timing, and warnings text.


Bug Fixes
Expand Down
25 changes: 23 additions & 2 deletions mycli/clistyle.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@
Token.Output.OddRow: "output.odd-row",
Token.Output.EvenRow: "output.even-row",
Token.Output.Null: "output.null",
Token.Output.Status: "output.status",
Token.Output.Timing: "output.timing",
Token.Warnings.Header: "warnings.header",
Token.Warnings.OddRow: "warnings.odd-row",
Token.Warnings.EvenRow: "warnings.even-row",
Token.Warnings.Null: "warnings.null",
Token.Warnings.Status: "warnings.status",
Token.Warnings.Timing: "warnings.timing",
Token.Prompt: "prompt",
Token.Continuation: "continuation",
}
Expand Down Expand Up @@ -96,7 +104,7 @@ def parse_pygments_style(
return token_type, style_dict[token_name]


def style_factory(name: str, cli_style: dict[str, str]) -> _MergedStyle:
def style_factory_toolkit(name: str, cli_style: dict[str, str]) -> _MergedStyle:
try:
style: PygmentsStyle = pygments.styles.get_style_by_name(name)
except ClassNotFound:
Expand Down Expand Up @@ -124,7 +132,11 @@ def style_factory(name: str, cli_style: dict[str, str]) -> _MergedStyle:
return merge_styles([style_from_pygments_cls(style), override_style, Style(prompt_styles)])


def style_factory_output(name: str, cli_style: dict[str, str]) -> PygmentsStyle:
def style_factory_helpers(
name: str,
cli_style: dict[str, str],
warnings: bool = False,
) -> PygmentsStyle:
try:
style: dict[PygmentsStyle | str, str] = pygments.styles.get_style_by_name(name).styles
except ClassNotFound:
Expand All @@ -144,6 +156,15 @@ def style_factory_output(name: str, cli_style: dict[str, str]) -> PygmentsStyle:
# TODO: cli helpers will have to switch to ptk.Style
logger.error("Unhandled style / class name: %s", token)

if warnings:
for warning_token in style:
if 'Warnings' not in str(warning_token):
continue
warning_str = str(warning_token)
output_str = warning_str.replace('Warnings', 'Output')
output_token = string_to_tokentype(output_str)
style[output_token] = style[warning_token]

class OutputStyle(PygmentsStyle):
default_style = ""
styles = style
Expand Down
60 changes: 49 additions & 11 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,21 @@
import click
from configobj import ConfigObj
import keyring
from prompt_toolkit import print_formatted_text
from prompt_toolkit.application.current import get_app
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import Completion, DynamicCompleter
from prompt_toolkit.document import Document
from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
from prompt_toolkit.filters import Condition, HasFocus, IsDone
from prompt_toolkit.formatted_text import ANSI, AnyFormattedText
from prompt_toolkit.formatted_text import (
ANSI,
HTML,
AnyFormattedText,
FormattedText,
to_formatted_text,
to_plain_text,
)
from prompt_toolkit.key_binding.bindings.named_commands import register as prompt_register
from prompt_toolkit.key_binding.key_processor import KeyPressEvent
from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor
Expand All @@ -54,7 +62,7 @@

from mycli import __version__
from mycli.clibuffer import cli_is_multiline
from mycli.clistyle import style_factory, style_factory_output
from mycli.clistyle import style_factory_helpers, style_factory_toolkit
from mycli.clitoolbar import create_toolbar_tokens_func
from mycli.compat import WIN
from mycli.completion_refresher import CompletionRefresher
Expand Down Expand Up @@ -206,7 +214,9 @@ def __init__(
self.syntax_style = c["main"]["syntax_style"]
self.less_chatty = c["main"].as_bool("less_chatty")
self.cli_style = c["colors"]
self.output_style = style_factory_output(self.syntax_style, self.cli_style)
self.toolkit_style = style_factory_toolkit(self.syntax_style, self.cli_style)
self.helpers_style = style_factory_helpers(self.syntax_style, self.cli_style)
self.helpers_warnings_style = style_factory_helpers(self.syntax_style, self.cli_style, warnings=True)
self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
c_dest_warning = c["main"].as_bool("destructive_warning")
self.destructive_warning = c_dest_warning if warn is None else warn
Expand Down Expand Up @@ -880,6 +890,13 @@ def handle_unprettify_binding(self, text: str) -> str:
unpretty_text = unpretty_text + ';'
return unpretty_text

def output_timing(self, timing: str, is_warnings_style: bool = False) -> None:
self.log_output(timing)
add_style = 'class:warnings.timing' if is_warnings_style else 'class:output.timing'
formatted_timing = FormattedText([('', timing)])
styled_timing = to_formatted_text(formatted_timing, style=add_style)
print_formatted_text(styled_timing, style=self.toolkit_style)

def run_cli(self) -> None:
iterations = 0
sqlexecute = self.sqlexecute
Expand Down Expand Up @@ -1001,7 +1018,7 @@ def output_res(results: Generator[SQLResult], start: float) -> None:
assert self.prompt_app is not None
self.prompt_app.output.bell()
if special.is_timing_enabled():
self.echo(f"Time: {t:0.03f}s")
self.output_timing(f"Time: {t:0.03f}s")
except KeyboardInterrupt:
pass

Expand All @@ -1012,7 +1029,10 @@ def output_res(results: Generator[SQLResult], start: float) -> None:
# get and display warnings if enabled
if self.show_warnings and isinstance(result.rows, Cursor) and result.rows.warning_count > 0:
warnings = sqlexecute.run("SHOW WARNINGS")
t = time() - start
saw_warning = False
for warning in warnings:
saw_warning = True
formatted = self.format_sqlresult(
warning,
is_expanded=special.is_expanded_output(),
Expand All @@ -1021,9 +1041,13 @@ def output_res(results: Generator[SQLResult], start: float) -> None:
numeric_alignment=self.numeric_alignment,
binary_display=self.binary_display,
max_width=max_width,
is_warnings_style=True,
)
self.echo("")
self.output(formatted, warning.status)
self.output(formatted, warning.status, is_warnings_style=True)

if saw_warning and special.is_timing_enabled():
self.output_timing(f"Time: {t:0.03f}s", is_warnings_style=True)

def keepalive_hook(_context):
"""
Expand Down Expand Up @@ -1105,7 +1129,7 @@ def one_iteration(text: str | None = None) -> None:
click.echo(context)
click.echo("---")
if special.is_timing_enabled():
click.echo(f"Time: {duration:.2f} seconds")
self.output_timing(f"Time: {duration:.2f} seconds")
text = self.prompt_app.prompt(
default=sql or '',
inputhook=inputhook,
Expand Down Expand Up @@ -1264,7 +1288,8 @@ def one_iteration(text: str | None = None) -> None:
auto_suggest=AutoSuggestFromHistory(),
complete_while_typing=complete_while_typing_filter,
multiline=cli_is_multiline(self),
style=style_factory(self.syntax_style, self.cli_style),
# why not self.toolkit_style here?
style=style_factory_toolkit(self.syntax_style, self.cli_style),
include_default_pygments_style=False,
key_bindings=key_bindings,
enable_open_in_editor=True,
Expand Down Expand Up @@ -1344,8 +1369,10 @@ def log_query(self, query: str) -> None:
self.logfile.write(query)
self.logfile.write("\n")

def log_output(self, output: str) -> None:
def log_output(self, output: str | AnyFormattedText) -> None:
"""Log the output in the audit log, if it's enabled."""
if isinstance(output, (ANSI, HTML, FormattedText)):
output = to_plain_text(output)
if isinstance(self.logfile, TextIOWrapper):
click.echo(output, file=self.logfile)

Expand All @@ -1371,7 +1398,12 @@ def get_output_margin(self, status: str | None = None) -> int:

return margin

def output(self, output: itertools.chain[str], status: str | None = None) -> None:
def output(
self,
output: itertools.chain[str],
status: str | None = None,
is_warnings_style: bool = False,
) -> None:
"""Output text to stdout or a pager command.

The status text is not outputted to pager or files.
Expand Down Expand Up @@ -1433,8 +1465,12 @@ def newlinewrapper(text: list[str]) -> Generator[str, None, None]:
click.secho(line)

if status:
# todo allow status to be a FormattedText, but strip before logging
self.log_output(status)
click.secho(status)
add_style = 'class:warnings.status' if is_warnings_style else 'class:output.status'
formatted_status = FormattedText([('', status)])
styled_status = to_formatted_text(formatted_status, style=add_style)
print_formatted_text(styled_status, style=self.toolkit_style)

def configure_pager(self) -> None:
# Provide sane defaults for less if they are empty.
Expand Down Expand Up @@ -1576,6 +1612,7 @@ def run_query(
null_string=self.null_string,
numeric_alignment=self.numeric_alignment,
binary_display=self.binary_display,
is_warnings_style=True,
)
for line in output:
click.echo(line, nl=new_line)
Expand All @@ -1592,6 +1629,7 @@ def format_sqlresult(
numeric_alignment: str = 'right',
binary_display: str | None = None,
max_width: int | None = None,
is_warnings_style: bool = False,
) -> itertools.chain[str]:
if is_redirected:
use_formatter = self.redirect_formatter
Expand All @@ -1605,7 +1643,7 @@ def format_sqlresult(
"dialect": "unix",
"disable_numparse": True,
"preserve_whitespace": True,
"style": self.output_style,
"style": self.helpers_warnings_style if is_warnings_style else self.helpers_style,
}
default_kwargs = use_formatter._output_formats[use_formatter.format_name].formatter_args

Expand Down
8 changes: 8 additions & 0 deletions mycli/myclirc
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,14 @@ output.header = "#00ff5f bold"
output.odd-row = ""
output.even-row = ""
output.null = "#808080"
output.status = ""
output.timing = ""
warnings.header = "#00ff5f bold"
warnings.odd-row = ""
warnings.even-row = ""
warnings.null = "#808080"
warnings.status = ""
warnings.timing = ""

# SQL syntax highlighting overrides
# sql.comment = 'italic #408080'
Expand Down
8 changes: 8 additions & 0 deletions test/myclirc
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,14 @@ output.header = "#00ff5f bold"
output.odd-row = ""
output.even-row = ""
output.null = "#808080"
output.status = ""
output.timing = ""
warnings.header = "#00ff5f bold"
warnings.odd-row = ""
warnings.even-row = ""
warnings.null = "#808080"
warnings.status = ""
warnings.timing = ""

# SQL syntax highlighting overrides
# sql.comment = 'italic #408080'
Expand Down
10 changes: 5 additions & 5 deletions test/test_clistyle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,24 @@
from pygments.token import Token
import pytest

from mycli.clistyle import style_factory
from mycli.clistyle import style_factory_toolkit


@pytest.mark.skip(reason="incompatible with new prompt toolkit")
def test_style_factory():
def test_style_factory_toolkit():
"""Test that a Pygments Style class is created."""
header = "bold underline #ansired"
cli_style = {"Token.Output.Header": header}
style = style_factory("default", cli_style)
style = style_factory_toolkit("default", cli_style)

assert isinstance(style(), Style)
assert Token.Output.Header in style.styles
assert header == style.styles[Token.Output.Header]


@pytest.mark.skip(reason="incompatible with new prompt toolkit")
def test_style_factory_unknown_name():
def test_style_factory_toolkit_unknown_name():
"""Test that an unrecognized name will not throw an error."""
style = style_factory("foobar", {})
style = style_factory_toolkit("foobar", {})

assert isinstance(style(), Style)
22 changes: 14 additions & 8 deletions test/test_main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# type: ignore

from collections import namedtuple
from contextlib import redirect_stdout
import csv
import io
import os
import shutil
from tempfile import NamedTemporaryFile
Expand Down Expand Up @@ -42,7 +44,7 @@


@dbtest
def test_binary_display_hex(executor, capsys):
def test_binary_display_hex(executor):
m = MyCli()
m.sqlexecute = SQLExecute(
None,
Expand Down Expand Up @@ -72,14 +74,16 @@ def test_binary_display_hex(executor, capsys):
binary_display="hex",
max_width=None,
)
m.output(formatted, sqlresult.status)
f = io.StringIO()
with redirect_stdout(f):
m.output(formatted, sqlresult.status)
expected = " 0x6a "
stdout = capsys.readouterr().out
assert expected in stdout
output = f.getvalue()
assert expected in output


@dbtest
def test_binary_display_utf8(executor, capsys):
def test_binary_display_utf8(executor):
m = MyCli()
m.sqlexecute = SQLExecute(
None,
Expand Down Expand Up @@ -109,10 +113,12 @@ def test_binary_display_utf8(executor, capsys):
binary_display="utf8",
max_width=None,
)
m.output(formatted, sqlresult.status)
f = io.StringIO()
with redirect_stdout(f):
m.output(formatted, sqlresult.status)
expected = " j "
stdout = capsys.readouterr().out
assert expected in stdout
output = f.getvalue()
assert expected in output


@dbtest
Expand Down