From 3ac50de0cab7755b7d86f15cb6954a8503f2dfbb Mon Sep 17 00:00:00 2001 From: Roland Walker Date: Mon, 2 Mar 2026 05:18:22 -0500 Subject: [PATCH] avoid calling get_prompt() unless needed Since some prompt escapes are expensive, and can even require a trip to the server, avoid calling get_prompt() unless needed, preferring to use the cached value in the last_prompt_message property, or a new saved value for the number of lines in the prompt. Even after caching, get_prompt() seems to be called two or three times for each prompt refresh, so there is more to do. Incidentally, explicitly strip ANSI formatting from prompts before writing them to a file, when "tee" is in effect. --- changelog.md | 2 ++ mycli/main.py | 34 ++++++++++++++++++++-------- mycli/packages/special/iocommands.py | 13 +++++++---- test/test_main.py | 4 ++-- 4 files changed, 37 insertions(+), 16 deletions(-) diff --git a/changelog.md b/changelog.md index f3341a94..2bf0dac8 100644 --- a/changelog.md +++ b/changelog.md @@ -15,6 +15,7 @@ Features Bug Fixes --------- * Make toolbar widths consistent on toggle actions. +* Don't write ANSI prompt escapes to `tee` output. Internal @@ -25,6 +26,7 @@ Internal * Add more URL constants. * Set `$VISUAL` whenever `$EDITOR` is set. * Fix tempfile leak in test suite. +* Avoid refreshing the prompt unless needed. 1.58.0 (2026/02/28) diff --git a/mycli/main.py b/mycli/main.py index 7aeb8aa2..31ae82b0 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -303,6 +303,7 @@ def __init__( self.my_cnf['mysqld'] = {} prompt_cnf = self.read_my_cnf(self.my_cnf, ["prompt"])["prompt"] self.prompt_format = prompt or prompt_cnf or c["main"]["prompt"] or self.default_prompt + self.prompt_lines = 0 self.multiline_continuation_char = c["main"]["prompt_continuation"] self.toolbar_format = toolbar_format or c['main']['toolbar'] self.prompt_app = None @@ -935,10 +936,13 @@ def run_cli(self) -> None: def get_prompt_message(app) -> ANSI: if app.current_buffer.text: return self.last_prompt_message - prompt = self.get_prompt(self.prompt_format) + prompt = self.get_prompt(self.prompt_format, app.render_counter) if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt: - prompt = self.get_prompt(self.default_prompt_splitln) + prompt = self.get_prompt(self.default_prompt_splitln, app.render_counter) + self.prompt_lines = prompt.count('\n') + 1 prompt = prompt.replace("\\x1b", "\x1b") + if not self.prompt_lines: + self.prompt_lines = prompt.count('\n') + 1 self.last_prompt_message = ANSI(prompt) return self.last_prompt_message @@ -1182,7 +1186,8 @@ def one_iteration(text: str | None = None) -> None: try: logger.debug("sql: %r", text) - special.write_tee(self.get_prompt(self.prompt_format) + text) + special.write_tee(self.last_prompt_message, nl=False) + special.write_tee(text) self.log_query(text) successful = False @@ -1397,7 +1402,11 @@ def echo(self, s: str, **kwargs) -> None: def get_output_margin(self, status: str | None = None) -> int: """Get the output margin (number of rows for the prompt, footer and timing message.""" - margin = self.get_reserved_space() + self.get_prompt(self.prompt_format).count("\n") + 1 + if not self.prompt_lines: + # self.prompt_app.app.render_counter failed in the test suite + app = get_app() + self.prompt_lines = self.get_prompt(self.prompt_format, app.render_counter).count('\n') + 1 + margin = self.get_reserved_space() + self.prompt_lines if special.is_timing_enabled(): margin += 1 if status: @@ -1534,13 +1543,18 @@ def get_completions(self, text: str, cursor_position: int) -> Iterable[Completio def get_custom_toolbar(self, toolbar_format: str) -> ANSI: if self.prompt_app and self.prompt_app.app.current_buffer.text: return self.last_custom_toolbar_message - toolbar = self.get_prompt(toolbar_format) + app = get_app() + toolbar = self.get_prompt(toolbar_format, app.render_counter) toolbar = toolbar.replace("\\x1b", "\x1b") self.last_custom_toolbar_message = ANSI(toolbar) return self.last_custom_toolbar_message - # todo: time/uptime update on every character typed, instead of after every return - def get_prompt(self, string: str) -> str: + # Memoizing a method leaks the instance, but we only expect one MyCli instance. + # Before memoizing, get_prompt() was called dozens of times per prompt. + # Even after memoizing, get_prompt's logic gets called twice per prompt, which + # should be addressed, because some format strings take a trip to the server. + @functools.lru_cache(maxsize=256) # noqa: B019 + def get_prompt(self, string: str, _render_counter: int) -> str: sqlexecute = self.sqlexecute assert sqlexecute is not None assert sqlexecute.server_info is not None @@ -1569,6 +1583,8 @@ def get_prompt(self, string: str) -> str: string = string.replace("\\k", os.path.basename(sqlexecute.socket or str(sqlexecute.port))) string = string.replace("\\K", sqlexecute.socket or str(sqlexecute.port)) string = string.replace("\\A", self.dsn_alias or "(none)") + string = string.replace("\\_", " ") + # jump through hoops for the test environment, and for efficiency if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None: if '\\y' in string: @@ -1581,14 +1597,13 @@ def get_prompt(self, string: str) -> str: string = string.replace('\\y', '(none)') string = string.replace('\\Y', '(none)') - string = string.replace("\\_", " ") - # jump through hoops for the test environment and for efficiency if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None: if '\\T' in string: with sqlexecute.conn.cursor() as cur: string = string.replace('\\T', get_ssl_version(cur) or '(none)') else: string = string.replace('\\T', '(none)') + if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None: if '\\w' in string: with sqlexecute.conn.cursor() as cur: @@ -1601,6 +1616,7 @@ def get_prompt(self, string: str) -> str: string = string.replace('\\W', str(get_warning_count(cur) or '')) else: string = string.replace('\\W', '') + return string def run_query( diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py index 39714075..cfcc3433 100644 --- a/mycli/packages/special/iocommands.py +++ b/mycli/packages/special/iocommands.py @@ -11,6 +11,7 @@ import click from configobj import ConfigObj +from prompt_toolkit.formatted_text import ANSI, FormattedText, to_plain_text from pymysql.cursors import Cursor import pyperclip import sqlparse @@ -432,12 +433,14 @@ def no_tee(arg: str, **_) -> list[SQLResult]: return [SQLResult(status="")] -def write_tee(output: str) -> None: +def write_tee(output: str | ANSI | FormattedText, nl: bool = True) -> None: global tee_file - if tee_file: - click.echo(output, file=tee_file, nl=False) - click.echo("\n", file=tee_file, nl=False) - tee_file.flush() + if not tee_file: + return + click.echo(to_plain_text(output), file=tee_file, nl=False) + if nl: + click.echo('\n', file=tee_file, nl=False) + tee_file.flush() @special_command("\\once", "\\once [-o] ", "Append next result to an output file (overwrite using -o).", aliases=["\\o"]) diff --git a/test/test_main.py b/test/test_main.py index be3a32a5..34ac1aaf 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -335,7 +335,7 @@ def test_prompt_no_host_only_socket(executor): mycli.sqlexecute.user = "root" mycli.sqlexecute.dbname = "mysql" mycli.sqlexecute.port = "3306" - prompt = mycli.get_prompt(mycli.prompt_format) + prompt = mycli.get_prompt(mycli.prompt_format, 0) assert prompt == "MySQL root@localhost:mysql> " @@ -350,7 +350,7 @@ def test_prompt_socket_overrides_port(executor): mycli.sqlexecute.user = "root" mycli.sqlexecute.dbname = "mysql" mycli.sqlexecute.port = "3306" - prompt = mycli.get_prompt(mycli.prompt_format) + prompt = mycli.get_prompt(mycli.prompt_format, 0) assert prompt == "MySQL root@localhost:mysqld.sock mysql> "