From 43288ea1e95f2097e828ca996700b787ddeeb1d7 Mon Sep 17 00:00:00 2001 From: pablogonzalezpe Date: Mon, 13 Apr 2026 02:00:41 +0200 Subject: [PATCH 1/2] contrib/mysql: add classic protocol support --- scapy/contrib/mysql.py | 965 +++++++++++++++++++++++++++++++++++++++++ test/contrib/mysql.uts | 453 +++++++++++++++++++ 2 files changed, 1418 insertions(+) create mode 100644 scapy/contrib/mysql.py create mode 100644 test/contrib/mysql.uts diff --git a/scapy/contrib/mysql.py b/scapy/contrib/mysql.py new file mode 100644 index 00000000000..73f9364e7a4 --- /dev/null +++ b/scapy/contrib/mysql.py @@ -0,0 +1,965 @@ +# SPDX-License-Identifier: GPL-2.0-only +# This file is part of Scapy +# See https://scapy.net/ for more information +# Author: Pablo Gonzalez + +# scapy.contrib.description = MySQL client/server protocol +# scapy.contrib.status = loads + +""" +MySQL client/server protocol. + +This contrib module implements support for the MySQL classic protocol over TCP, +including packet framing, common handshake/authentication messages, query +packets, text resultsets, prepared statement metadata, and some legacy flows +seen in real captures. + +Currently supported messages include: + +- Protocol::HandshakeV10 +- Protocol::SSLRequest +- Protocol::HandshakeResponse41 +- OldAuthSwitchRequest +- AuthSwitchRequest +- AuthSwitchResponse +- AuthMoreData +- OK_Packet +- ERR_Packet +- EOF_Packet +- COM_QUERY +- COM_STMT_PREPARE_OK +- text resultset column counts, column definitions, and rows + +This module does not currently implement TLS-encrypted MySQL payloads, +compression, binary resultsets, or full command/authentication coverage. +""" + +import struct +from typing import Any, Optional, Tuple + +from scapy.compat import orb +from scapy.fields import ( + ByteEnumField, + ByteField, + ConditionalField, + Field, + LEIntField, + LEShortEnumField, + LEShortField, + LEThreeBytesField, + PacketListField, + StrField, + StrFixedLenField, + StrLenField, + StrNullField, +) +from scapy.layers.inet import TCP +from scapy.packet import Packet, Raw, bind_layers +from scapy.sessions import TCPSession + +__all__ = [ + "MYSQL_PORT", + "CLIENT_PROTOCOL_41", + "CLIENT_SSL", + "CLIENT_CONNECT_WITH_DB", + "CLIENT_SECURE_CONNECTION", + "CLIENT_PLUGIN_AUTH", + "CLIENT_CONNECT_ATTRS", + "CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA", + "CLIENT_DEPRECATE_EOF", + "MySQLClient", + "MySQLServer", + "MySQLClientPacket", + "MySQLServerPacket", + "MySQLHandshakeV10", + "MySQLSSLRequest", + "MySQLHandshakeResponse41", + "MySQLOldAuthSwitchRequest", + "MySQLAuthSwitchRequest", + "MySQLAuthMoreData", + "MySQLAuthSwitchResponse", + "MySQLStmtPrepareOK", + "MySQLResultSetColumnCount", + "MySQLColumnDefinition41", + "MySQLTextResultSetRow", + "MySQLOKPacket", + "MySQLErrPacket", + "MySQLEOFPacket", + "MySQLCommand", + "MySQLComQuery", +] + +MYSQL_PORT = 3306 + +CLIENT_LONG_PASSWORD = 0x00000001 +CLIENT_LONG_FLAG = 0x00000004 +CLIENT_CONNECT_WITH_DB = 0x00000008 +CLIENT_COMPRESS = 0x00000020 +CLIENT_LOCAL_FILES = 0x00000080 +CLIENT_PROTOCOL_41 = 0x00000200 +CLIENT_SSL = 0x00000800 +CLIENT_TRANSACTIONS = 0x00002000 +CLIENT_SECURE_CONNECTION = 0x00008000 +CLIENT_MULTI_STATEMENTS = 0x00010000 +CLIENT_MULTI_RESULTS = 0x00020000 +CLIENT_PLUGIN_AUTH = 0x00080000 +CLIENT_CONNECT_ATTRS = 0x00100000 +CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 0x00200000 +CLIENT_SESSION_TRACK = 0x00800000 +CLIENT_DEPRECATE_EOF = 0x01000000 +CLIENT_ZSTD_COMPRESSION_ALGORITHM = 0x04000000 +CLIENT_QUERY_ATTRIBUTES = 0x08000000 + +MYSQL_COMMANDS = { + 0x01: "COM_QUIT", + 0x02: "COM_INIT_DB", + 0x03: "COM_QUERY", + 0x04: "COM_FIELD_LIST", + 0x0E: "COM_PING", + 0x16: "COM_STMT_PREPARE", + 0x17: "COM_STMT_EXECUTE", + 0x19: "COM_STMT_CLOSE", +} + +MYSQL_CHARACTER_SETS = { + 0x08: "latin1_swedish_ci", + 0x21: "utf8_general_ci", + 0x2D: "utf8mb4_general_ci", + 0x2E: "utf8mb4_bin", + 0x3F: "binary", + 0xFF: "utf8mb4_0900_ai_ci", +} + +MYSQL_CLIENT_FLAGS = { + CLIENT_LONG_PASSWORD: "LONG_PASSWORD", + CLIENT_LONG_FLAG: "LONG_FLAG", + CLIENT_CONNECT_WITH_DB: "CONNECT_WITH_DB", + CLIENT_COMPRESS: "COMPRESS", + CLIENT_LOCAL_FILES: "LOCAL_FILES", + CLIENT_PROTOCOL_41: "PROTOCOL_41", + CLIENT_SSL: "SSL", + CLIENT_TRANSACTIONS: "TRANSACTIONS", + CLIENT_SECURE_CONNECTION: "SECURE_CONNECTION", + CLIENT_MULTI_STATEMENTS: "MULTI_STATEMENTS", + CLIENT_MULTI_RESULTS: "MULTI_RESULTS", + CLIENT_PLUGIN_AUTH: "PLUGIN_AUTH", + CLIENT_CONNECT_ATTRS: "CONNECT_ATTRS", + CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA: "PLUGIN_AUTH_LENENC_CLIENT_DATA", + CLIENT_SESSION_TRACK: "SESSION_TRACK", + CLIENT_DEPRECATE_EOF: "DEPRECATE_EOF", + CLIENT_ZSTD_COMPRESSION_ALGORITHM: "ZSTD_COMPRESSION_ALGORITHM", + CLIENT_QUERY_ATTRIBUTES: "QUERY_ATTRIBUTES", +} + +MYSQL_STATUS_FLAGS = { + 0x0001: "IN_TRANS", + 0x0002: "AUTOCOMMIT", + 0x0008: "MORE_RESULTS_EXISTS", + 0x0010: "NO_GOOD_INDEX_USED", + 0x0020: "NO_INDEX_USED", + 0x0040: "CURSOR_EXISTS", + 0x0080: "LAST_ROW_SENT", + 0x0100: "DB_DROPPED", + 0x0200: "NO_BACKSLASH_ESCAPES", + 0x0400: "METADATA_CHANGED", + 0x0800: "QUERY_WAS_SLOW", + 0x1000: "PS_OUT_PARAMS", + 0x2000: "IN_TRANS_READONLY", + 0x4000: "SESSION_STATE_CHANGED", +} + +MYSQL_COLUMN_TYPES = { + 0x00: "DECIMAL", + 0x01: "TINY", + 0x02: "SHORT", + 0x03: "LONG", + 0x04: "FLOAT", + 0x05: "DOUBLE", + 0x06: "NULL", + 0x07: "TIMESTAMP", + 0x08: "LONGLONG", + 0x09: "INT24", + 0x0A: "DATE", + 0x0B: "TIME", + 0x0C: "DATETIME", + 0x0D: "YEAR", + 0x0F: "VARCHAR", + 0x10: "BIT", + 0xF5: "JSON", + 0xF6: "NEWDECIMAL", + 0xF7: "ENUM", + 0xF8: "SET", + 0xF9: "TINY_BLOB", + 0xFA: "MEDIUM_BLOB", + 0xFB: "LONG_BLOB", + 0xFC: "BLOB", + 0xFD: "VAR_STRING", + 0xFE: "STRING", + 0xFF: "GEOMETRY", +} + +MYSQL_COLUMN_FLAGS = { + 0x0001: "NOT_NULL", + 0x0002: "PRI_KEY", + 0x0004: "UNIQUE_KEY", + 0x0008: "MULTIPLE_KEY", + 0x0010: "BLOB", + 0x0020: "UNSIGNED", + 0x0040: "ZEROFILL", + 0x0080: "BINARY", + 0x0100: "ENUM", + 0x0200: "AUTO_INCREMENT", + 0x0400: "TIMESTAMP", + 0x0800: "SET", +} + + +def _capability(flags: int, mask: int) -> bool: + return bool(flags & mask) + + +def _flag_repr(value: int, mapping: Any) -> str: + names = [name for mask, name in mapping.items() if value & mask] + if names: + return "%d (%s)" % (value, "|".join(names)) + return repr(value) + + +def _read_lenenc_int(data: bytes) -> Tuple[int, int]: + """Decode a MySQL length-encoded integer and return (value, consumed).""" + if not data: + return 0, 0 + first = orb(data[0]) + if first < 0xFB: + return first, 1 + if first == 0xFC: + return struct.unpack(" bytes: + if value < 0xFB: + return struct.pack("B", value) + if value < (1 << 16): + return b"\xFC" + struct.pack(" bool: + remain = payload + parsed = 0 + while remain and parsed < column_count: + if remain[:1] == b"\xFB": + remain = remain[1:] + parsed += 1 + continue + try: + length, size = _read_lenenc_int(remain) + except struct.error: + return False + end = size + length + if end > len(remain): + return False + remain = remain[end:] + parsed += 1 + return parsed == column_count and not remain + + +def _can_parse_column_definition(payload: bytes) -> bool: + try: + pkt = MySQLColumnDefinition41(payload) + return bytes(pkt) == payload + except Exception: + return False + + +class MySQLLenEncIntField(Field[Any, Any]): + def __init__(self, name: str, default: Any = 0) -> None: + Field.__init__(self, name, default) + + def addfield(self, pkt: Packet, s: bytes, val: Any) -> bytes: + if val is None: + val = 0 + return s + _build_lenenc_int(int(val)) + + def getfield(self, pkt: Packet, s: bytes) -> Tuple[bytes, Any]: + value, size = _read_lenenc_int(s) + return s[size:], value + + def i2repr(self, pkt: Optional[Packet], val: Any) -> str: + return repr(val) + + +class MySQLCapabilityFlagsField(LEIntField): + def i2repr(self, pkt: Optional[Packet], val: Any) -> str: + return _flag_repr(int(val), MYSQL_CLIENT_FLAGS) + + +class MySQLStatusFlagsField(LEShortField): + def i2repr(self, pkt: Optional[Packet], val: Any) -> str: + return _flag_repr(int(val), MYSQL_STATUS_FLAGS) + + +class MySQLCharsetField(ByteEnumField): + def __init__(self, name: str, default: int = 0) -> None: + ByteEnumField.__init__(self, name, default, MYSQL_CHARACTER_SETS) + + +class MySQLShortCharsetField(LEShortEnumField): + def __init__(self, name: str, default: int = 0) -> None: + LEShortEnumField.__init__(self, name, default, MYSQL_CHARACTER_SETS) + + +class MySQLColumnFlagsField(LEShortField): + def i2repr(self, pkt: Optional[Packet], val: Any) -> str: + return _flag_repr(int(val), MYSQL_COLUMN_FLAGS) + + +class MySQLLenEncStrField(Field[Any, Any]): + def __init__(self, name: str, default: Any = b"") -> None: + Field.__init__(self, name, default) + + def addfield(self, pkt: Packet, s: bytes, val: Any) -> bytes: + if val is None: + val = b"" + if isinstance(val, str): + val = val.encode("utf-8") + return s + _build_lenenc_int(len(val)) + val + + def getfield(self, pkt: Packet, s: bytes) -> Tuple[bytes, Any]: + length, size = _read_lenenc_int(s) + start = size + end = size + length + return s[end:], s[start:end] + + def i2repr(self, pkt: Optional[Packet], val: Any) -> str: + return repr(val) + + +class MySQLAuthResponseField(Field[Any, Any]): + """ + Authentication response encoding depends on client capabilities. + + - CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA: string + - CLIENT_SECURE_CONNECTION: int<1> + string + - otherwise: string + """ + + def __init__(self, name: str, default: Any = b"") -> None: + Field.__init__(self, name, default) + + def addfield(self, pkt: Packet, s: bytes, val: Any) -> bytes: + if val is None: + val = b"" + if isinstance(val, str): + val = val.encode("utf-8") + flags = getattr(pkt, "client_flags", 0) + if _capability(flags, CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA): + return s + _build_lenenc_int(len(val)) + val + if _capability(flags, CLIENT_SECURE_CONNECTION): + return s + struct.pack("B", len(val)) + val + return s + val + b"\x00" + + def getfield(self, pkt: Packet, s: bytes) -> Tuple[bytes, Any]: + flags = getattr(pkt, "client_flags", 0) + if _capability(flags, CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA): + length, size = _read_lenenc_int(s) + start = size + end = size + length + return s[end:], s[start:end] + if _capability(flags, CLIENT_SECURE_CONNECTION): + if not s: + return s, b"" + length = orb(s[0]) + return s[1 + length:], s[1:1 + length] + end = s.find(b"\x00") + if end < 0: + return b"", s + return s[end + 1:], s[:end] + + def i2repr(self, pkt: Optional[Packet], val: Any) -> str: + return repr(val) + + +class MySQLTextRowValuesField(Field[Any, Any]): + def __init__(self, name: str, default: Any = None) -> None: + Field.__init__(self, name, [] if default is None else default) + + def addfield(self, pkt: Packet, s: bytes, val: Any) -> bytes: + if val is None: + val = [] + elif isinstance(val, (bytes, str)): + val = [val] + for item in val: + if item is None: + s += b"\xFB" + continue + if isinstance(item, str): + item = item.encode("utf-8") + s += _build_lenenc_int(len(item)) + item + return s + + def getfield(self, pkt: Packet, s: bytes) -> Tuple[bytes, Any]: + values = [] + remain = s + while remain: + if remain[:1] == b"\xFB": + values.append(None) + remain = remain[1:] + continue + length, size = _read_lenenc_int(remain) + start = size + end = size + length + values.append(remain[start:end]) + remain = remain[end:] + return b"", values + + def i2repr(self, pkt: Optional[Packet], val: Any) -> str: + return repr(val) + + +class _MySQLPacket(Packet): + fields_desc = [ + LEThreeBytesField("payload_length", None), + ByteField("sequence_id", 0), + ] + + def do_build(self) -> bytes: + pkt = self.self_build() + pay = self.do_build_payload() + return self.post_build(pkt, pay) + + def post_build(self, pkt: bytes, pay: bytes) -> bytes: + if self.payload_length is None: + pkt = struct.pack(" Tuple[bytes, bytes]: + length = self.payload_length or 0 + return s[:length], s[length:] + + +class MySQLHandshakeV10(Packet): + name = "MySQL HandshakeV10" + fields_desc = [ + ByteField("protocol_version", 10), + StrNullField("server_version", b""), + LEIntField("connection_id", 0), + StrFixedLenField("auth_plugin_data_part_1", b"", length=8), + ByteField("filler", 0), + LEShortField("capability_flags_lower", 0), + MySQLCharsetField("character_set", 0), + MySQLStatusFlagsField("status_flags", 0), + LEShortField("capability_flags_upper", 0), + ByteField("auth_plugin_data_len", 0), + StrFixedLenField("reserved", b"\x00" * 10, length=10), + StrLenField( + "auth_plugin_data_part_2", + b"", + length_from=lambda pkt: max(13, pkt.auth_plugin_data_len - 8) + if pkt.auth_plugin_data_len + else 0, + ), + StrNullField("auth_plugin_name", b""), + ] + + @property + def capability_flags(self) -> int: + return ( + ((self.capability_flags_upper & 0xFFFF) << 16) + | (self.capability_flags_lower & 0xFFFF) + ) + + +class MySQLSSLRequest(Packet): + name = "MySQL SSLRequest" + fields_desc = [ + MySQLCapabilityFlagsField( + "client_flags", + CLIENT_PROTOCOL_41 | CLIENT_SSL, + ), + LEIntField("max_packet_size", 0), + MySQLCharsetField("character_set", 0), + StrFixedLenField("filler", b"\x00" * 23, length=23), + ] + + +class MySQLHandshakeResponse41(Packet): + name = "MySQL HandshakeResponse41" + fields_desc = [ + MySQLCapabilityFlagsField("client_flags", CLIENT_PROTOCOL_41), + LEIntField("max_packet_size", 0), + MySQLCharsetField("character_set", 0), + StrFixedLenField("filler", b"\x00" * 23, length=23), + StrNullField("username", b""), + MySQLAuthResponseField("auth_response", b""), + ConditionalField( + StrNullField("database", b""), + lambda pkt: _capability(pkt.client_flags, CLIENT_CONNECT_WITH_DB), + ), + ConditionalField( + StrNullField("auth_plugin_name", b""), + lambda pkt: _capability(pkt.client_flags, CLIENT_PLUGIN_AUTH), + ), + ConditionalField( + MySQLLenEncStrField("connect_attrs", b""), + lambda pkt: _capability(pkt.client_flags, CLIENT_CONNECT_ATTRS), + ), + ConditionalField( + ByteField("zstd_compression_level", 0), + lambda pkt: _capability( + pkt.client_flags, + CLIENT_ZSTD_COMPRESSION_ALGORITHM, + ), + ), + ] + + +class MySQLAuthSwitchRequest(Packet): + name = "MySQL AuthSwitchRequest" + fields_desc = [ + ByteField("header", 0xFE), + StrNullField("plugin_name", b""), + StrField("plugin_data", b""), + ] + + +class MySQLAuthSwitchResponse(Packet): + name = "MySQL AuthSwitchResponse" + fields_desc = [ + StrField("data", b""), + ] + + +class MySQLOldAuthSwitchRequest(Packet): + name = "MySQL OldAuthSwitchRequest" + fields_desc = [ + ByteField("header", 0xFE), + ] + + +class MySQLAuthMoreData(Packet): + name = "MySQL AuthMoreData" + fields_desc = [ + ByteField("header", 0x01), + StrField("data", b""), + ] + + +class MySQLStmtPrepareOK(Packet): + name = "MySQL COM_STMT_PREPARE_OK" + fields_desc = [ + ByteField("status", 0x00), + LEIntField("statement_id", 0), + LEShortField("num_columns", 0), + LEShortField("num_params", 0), + ByteField("reserved_1", 0), + LEShortField("warning_count", 0), + ] + + +class MySQLResultSetColumnCount(Packet): + name = "MySQL ResultSet Column Count" + fields_desc = [ + MySQLLenEncIntField("column_count", 0), + ] + + +class MySQLColumnDefinition41(Packet): + name = "MySQL ColumnDefinition41" + fields_desc = [ + MySQLLenEncStrField("catalog", b"def"), + MySQLLenEncStrField("schema", b""), + MySQLLenEncStrField("table", b""), + MySQLLenEncStrField("org_table", b""), + MySQLLenEncStrField("column_name", b""), + MySQLLenEncStrField("org_column_name", b""), + MySQLLenEncIntField("fixed_length_fields_len", 0x0C), + MySQLShortCharsetField("character_set", 0), + LEIntField("column_length", 0), + ByteEnumField("column_type", 0xFD, MYSQL_COLUMN_TYPES), + MySQLColumnFlagsField("flags", 0), + ByteField("decimals", 0), + LEShortField("filler", 0), + ] + + +class MySQLTextResultSetRow(Packet): + name = "MySQL Text ResultSet Row" + fields_desc = [ + MySQLTextRowValuesField("values", []), + ] + + def do_build(self) -> bytes: + pkt = self.self_build() + pay = self.do_build_payload() + return self.post_build(pkt, pay) + + +class MySQLOKPacket(Packet): + name = "MySQL OK_Packet" + fields_desc = [ + ByteField("header", 0x00), + MySQLLenEncIntField("affected_rows", 0), + MySQLLenEncIntField("last_insert_id", 0), + MySQLStatusFlagsField("status_flags", 0), + LEShortField("warnings", 0), + StrField("info", b""), + ] + + +class MySQLErrPacket(Packet): + name = "MySQL ERR_Packet" + fields_desc = [ + ByteField("header", 0xFF), + LEShortField("error_code", 0), + StrFixedLenField("sql_state_marker", b"#", length=1), + StrFixedLenField("sql_state", b"HY000", length=5), + StrField("error_message", b""), + ] + + +class MySQLEOFPacket(Packet): + name = "MySQL EOF_Packet" + fields_desc = [ + ByteField("header", 0xFE), + ConditionalField( + LEShortField("warnings", 0), + lambda pkt: getattr( + getattr(pkt, "underlayer", None), + "payload_length", + None, + ) != 1, + ), + ConditionalField( + MySQLStatusFlagsField("status_flags", 0), + lambda pkt: getattr( + getattr(pkt, "underlayer", None), + "payload_length", + None, + ) != 1, + ), + ] + + +class MySQLCommand(Packet): + name = "MySQL Command" + fields_desc = [ + ByteEnumField("cmd", 0x03, MYSQL_COMMANDS), + StrField("data", b""), + ] + + +class MySQLComQuery(Packet): + name = "MySQL COM_QUERY" + fields_desc = [ + ByteEnumField("cmd", 0x03, MYSQL_COMMANDS), + StrField("query", b""), + ] + + +def _guess_mysql_client_payload( + pkt: _MySQLPacket, + payload: bytes, +) -> type: + if len(payload) >= 32 and pkt.sequence_id == 1: + flags = struct.unpack(" 1: + return MySQLAuthSwitchResponse + return Raw + + +def _guess_mysql_server_payload( + pkt: _MySQLPacket, + payload: bytes, +) -> type: + if payload and pkt.sequence_id == 0 and orb(payload[0]) == 0x0A: + return MySQLHandshakeV10 + if payload: + header = orb(payload[0]) + if header == 0x00: + if pkt.sequence_id == 1 and len(payload) == 12: + return MySQLStmtPrepareOK + return MySQLOKPacket + if header == 0x01 and len(payload) > 1 and pkt.sequence_id > 0: + return MySQLAuthMoreData + if header == 0xFF: + return MySQLErrPacket + if header == 0xFE and len(payload) >= 9: + return MySQLAuthSwitchRequest + if header == 0xFE and len(payload) == 1: + if pkt.sequence_id == 2: + return MySQLOldAuthSwitchRequest + return MySQLEOFPacket + if header == 0xFE and len(payload) < 9: + return MySQLEOFPacket + return Raw + + +class MySQLClientPacket(_MySQLPacket): + name = "MySQL Client Packet" + + def guess_payload_class(self, payload: bytes) -> type: + return _guess_mysql_client_payload(self, payload) + + +class MySQLServerPacket(_MySQLPacket): + name = "MySQL Server Packet" + + def guess_payload_class(self, payload: bytes) -> type: + return _guess_mysql_server_payload(self, payload) + + +class _MySQLServerResultSetColumnCountPacket(MySQLServerPacket): + def guess_payload_class(self, payload: bytes) -> type: + return MySQLResultSetColumnCount + + +class _MySQLServerColumnDefinitionPacket(MySQLServerPacket): + def guess_payload_class(self, payload: bytes) -> type: + return MySQLColumnDefinition41 + + +class _MySQLServerTextResultSetRowPacket(MySQLServerPacket): + def guess_payload_class(self, payload: bytes) -> type: + return MySQLTextResultSetRow + + +class _MySQLServerEOFPacket(MySQLServerPacket): + def guess_payload_class(self, payload: bytes) -> type: + return MySQLEOFPacket + + +def _mysql_client_cls( + pkt: Packet, + lst: Any, + cur: bytes, + remain: bytes, +) -> Optional[type]: + if len(remain) < 4: + return None + return MySQLClientPacket + + +def _mysql_server_resultset_state(lst: Any) -> Optional[Any]: + state = None + for item in lst: + payload = getattr(item, "payload", None) + if isinstance(payload, MySQLResultSetColumnCount): + state = { + "column_count": payload.column_count, + "column_defs": 0, + "metadata_done": False, + } + continue + if state is None: + continue + if isinstance(payload, MySQLColumnDefinition41): + state["column_defs"] += 1 + continue + if isinstance(payload, MySQLEOFPacket): + if not state["metadata_done"] and ( + state["column_defs"] >= state["column_count"] + ): + state["metadata_done"] = True + elif state["metadata_done"]: + state = None + continue + if isinstance(payload, MySQLOKPacket): + if not state["metadata_done"] and ( + state["column_defs"] >= state["column_count"] + ): + state["metadata_done"] = True + elif state["metadata_done"]: + state = None + continue + if isinstance(payload, MySQLErrPacket): + state = None + continue + return state + + +def _mysql_server_stmt_prepare_state(lst: Any) -> Optional[Any]: + state = None + for item in lst: + payload = getattr(item, "payload", None) + if isinstance(payload, MySQLStmtPrepareOK): + phase = None + if payload.num_params: + phase = "params" + elif payload.num_columns: + phase = "columns" + state = { + "params_remaining": payload.num_params, + "columns_remaining": payload.num_columns, + "phase": phase, + } + continue + if state is None: + continue + if isinstance(payload, MySQLColumnDefinition41): + if state["phase"] == "params" and state["params_remaining"] > 0: + state["params_remaining"] -= 1 + if state["params_remaining"] == 0: + state["phase"] = "params_eof" + elif ( + state["phase"] == "columns" + and state["columns_remaining"] > 0 + ): + state["columns_remaining"] -= 1 + if state["columns_remaining"] == 0: + state["phase"] = "columns_eof" + continue + if isinstance(payload, MySQLEOFPacket): + if state["phase"] == "params_eof": + if state["columns_remaining"] > 0: + state["phase"] = "columns" + else: + state = None + elif state["phase"] == "columns_eof": + state = None + continue + if isinstance(payload, MySQLErrPacket): + state = None + continue + return state + + +def _mysql_server_field_list_state(lst: Any) -> Optional[Any]: + state = None + for item in lst: + payload = getattr(item, "payload", None) + if isinstance(payload, MySQLColumnDefinition41): + if state is None: + state = {"metadata_done": False} + continue + if state is None: + continue + if isinstance(payload, MySQLEOFPacket): + state["metadata_done"] = True + state = None + continue + if isinstance(payload, MySQLErrPacket): + state = None + continue + state = None + return state + + +def _mysql_server_cls( + pkt: Packet, + lst: Any, + cur: bytes, + remain: bytes, +) -> Optional[type]: + if len(remain) < 4: + return None + payload_length = struct.unpack("= 7: + return MySQLServerPacket + return MySQLServerPacket + if payload: + header = orb(payload[0]) + if header in (0x00, 0x0A, 0xFE, 0xFF): + return MySQLServerPacket + if header == 0x01 and payload_length > 1: + return MySQLServerPacket + if _can_parse_column_definition(payload): + return _MySQLServerColumnDefinitionPacket + if header != 0xFB: + return _MySQLServerResultSetColumnCountPacket + return MySQLServerPacket + + +def _mysql_stream_complete(data: bytes) -> bool: + offset = 0 + while offset < len(data): + if len(data) - offset < 4: + return False + payload_length = struct.unpack( + " Optional[Packet]: + if data and _mysql_stream_complete(data): + return cls(data) + return None + + +class MySQLClient(_MySQLStream): + name = "MySQL Client Stream" + fields_desc = [ + PacketListField("contents", [], next_cls_cb=_mysql_client_cls), + ] + + +class MySQLServer(_MySQLStream): + name = "MySQL Server Stream" + fields_desc = [ + PacketListField("contents", [], next_cls_cb=_mysql_server_cls), + ] + + +bind_layers(TCP, MySQLClient, dport=MYSQL_PORT) +bind_layers(TCP, MySQLServer, sport=MYSQL_PORT) diff --git a/test/contrib/mysql.uts b/test/contrib/mysql.uts new file mode 100644 index 00000000000..f87d94c6cd3 --- /dev/null +++ b/test/contrib/mysql.uts @@ -0,0 +1,453 @@ +# MySQL related regression tests +# Author: Pablo Gonzalez +# +# Type the following command to launch the tests: +# $ test/run_tests -P "load_contrib('mysql')" -t test/contrib/mysql.uts + ++ mysql + += mysql initialization + +from scapy.contrib.mysql import * + += handshakev10 build and parse + +handshake_pkt = MySQLServerPacket(sequence_id=0) / MySQLHandshakeV10( + server_version=b"8.0.36", + connection_id=1337, + auth_plugin_data_part_1=b"12345678", + capability_flags_lower=0xFFFF, + character_set=0x21, + status_flags=0x0002, + capability_flags_upper=0x0018, + auth_plugin_data_len=21, + auth_plugin_data_part_2=b"ABCDEFGHIJKL\x00", + auth_plugin_name=b"mysql_native_password", +) + +handshake_raw = bytes(handshake_pkt) +assert handshake_raw == bytes.fromhex( + "4a000000" + "0a382e302e333600" + "39050000" + "3132333435363738" + "00" + "ffff" + "21" + "0200" + "1800" + "15" + "00000000000000000000" + "4142434445464748494a4b4c00" + "6d7973716c5f6e61746976655f70617373776f726400" +) + +handshake_stream = MySQLServer(handshake_raw) +assert len(handshake_stream.contents) == 1 +handshake = handshake_stream.contents[0].payload +assert isinstance(handshake, MySQLHandshakeV10) +assert handshake.server_version == b"8.0.36" +assert handshake.connection_id == 1337 +assert handshake.capability_flags == 0x0018FFFF +assert handshake.auth_plugin_name == b"mysql_native_password" + += sslrequest build and parse + +ssl_request_pkt = MySQLClientPacket(sequence_id=1) / MySQLSSLRequest( + client_flags=CLIENT_PROTOCOL_41 | CLIENT_SSL, + max_packet_size=0x01000000, + character_set=0x21, +) + +ssl_request_raw = bytes(ssl_request_pkt) +ssl_request = MySQLClient(ssl_request_raw) +assert isinstance(ssl_request.contents[0].payload, MySQLSSLRequest) +assert ssl_request.contents[0].payload.client_flags == ( + CLIENT_PROTOCOL_41 | CLIENT_SSL +) +assert ssl_request.contents[0].payload.max_packet_size == 0x01000000 + += handshake response official example + +handshake_response_raw = bytes.fromhex( + "54000001" + "8da60f00" + "00000001" + "08" + "0000000000000000000000000000000000000000000000" + "70616d00" + "14" + "ab09eef6bcb1323e61143865c0991d957d75d447" + "7465737400" + "6d7973716c5f6e61746976655f70617373776f726400" +) + +handshake_response_stream = MySQLClient(handshake_response_raw) +assert len(handshake_response_stream.contents) == 1 +handshake_response = handshake_response_stream.contents[0].payload +assert isinstance(handshake_response, MySQLHandshakeResponse41) +assert handshake_response.username == b"pam" +assert len(handshake_response.auth_response) == 20 +assert handshake_response.database == b"test" +assert handshake_response.auth_plugin_name == b"mysql_native_password" + += auth switch packets + +auth_switch_request_pkt = MySQLServerPacket(sequence_id=2) / MySQLAuthSwitchRequest( + plugin_name=b"mysql_native_password", + plugin_data=b"1234567890abcdefghij", +) + +auth_switch_request_raw = bytes(auth_switch_request_pkt) +auth_switch_request_stream = MySQLServer(auth_switch_request_raw) +auth_switch_request = auth_switch_request_stream.contents[0].payload +assert isinstance(auth_switch_request, MySQLAuthSwitchRequest) +assert auth_switch_request.header == 0xFE +assert auth_switch_request.plugin_name == b"mysql_native_password" +assert auth_switch_request.plugin_data == b"1234567890abcdefghij" + +auth_switch_response_raw = bytes.fromhex( + "14000003" + "f417961f79f3ac100bdaa6b3b5c20eab5985ffb8" +) +auth_switch_response_stream = MySQLClient(auth_switch_response_raw) +auth_switch_response = auth_switch_response_stream.contents[0].payload +assert isinstance(auth_switch_response, MySQLAuthSwitchResponse) +assert auth_switch_response.data == bytes.fromhex( + "f417961f79f3ac100bdaa6b3b5c20eab5985ffb8" +) + += old auth switch request packet + +old_auth_switch_raw = bytes.fromhex("01000002fe") +old_auth_switch_stream = MySQLServer(old_auth_switch_raw) +old_auth_switch = old_auth_switch_stream.contents[0].payload +assert isinstance(old_auth_switch, MySQLOldAuthSwitchRequest) +assert old_auth_switch.header == 0xFE + += auth more data packet + +auth_more_data_pkt = MySQLServerPacket(sequence_id=4) / MySQLAuthMoreData( + data=b"\x03" +) +auth_more_data_raw = bytes(auth_more_data_pkt) +assert auth_more_data_raw == bytes.fromhex("020000040103") +auth_more_data_stream = MySQLServer(auth_more_data_raw) +auth_more_data = auth_more_data_stream.contents[0].payload +assert isinstance(auth_more_data, MySQLAuthMoreData) +assert auth_more_data.header == 0x01 +assert auth_more_data.data == b"\x03" + += stmt prepare ok and metadata packets + +stmt_prepare_ok_pkt = MySQLServerPacket(sequence_id=1) / MySQLStmtPrepareOK( + statement_id=1, + num_columns=1, + num_params=2, + warning_count=0, +) +param_def_1_pkt = MySQLServerPacket(sequence_id=2) / MySQLColumnDefinition41( + catalog=b"def", + schema=b"", + table=b"", + org_table=b"", + column_name=b"?", + org_column_name=b"?", + fixed_length_fields_len=0x0C, + character_set=0x3F, + column_length=0, + column_type=0xFD, + flags=0x0080, + decimals=0, + filler=0, +) +param_def_2_pkt = MySQLServerPacket(sequence_id=3) / MySQLColumnDefinition41( + catalog=b"def", + schema=b"", + table=b"", + org_table=b"", + column_name=b"?", + org_column_name=b"?", + fixed_length_fields_len=0x0C, + character_set=0x3F, + column_length=0, + column_type=0xFD, + flags=0x0080, + decimals=0, + filler=0, +) +column_def_prepare_pkt = MySQLServerPacket(sequence_id=5) / MySQLColumnDefinition41( + catalog=b"def", + schema=b"demo", + table=b"lots", + org_table=b"lots", + column_name=b"id", + org_column_name=b"id", + fixed_length_fields_len=0x0C, + character_set=0x3F, + column_length=11, + column_type=0x03, + flags=0x4203, + decimals=0, + filler=0, +) +stmt_prepare_stream_raw = ( + bytes(stmt_prepare_ok_pkt) + + bytes(param_def_1_pkt) + + bytes(param_def_2_pkt) + + bytes(MySQLServerPacket(sequence_id=4) / MySQLEOFPacket(warnings=0, status_flags=0x0002)) + + bytes(column_def_prepare_pkt) + + bytes(MySQLServerPacket(sequence_id=6) / MySQLEOFPacket(warnings=0, status_flags=0x0002)) +) +stmt_prepare_stream = MySQLServer(stmt_prepare_stream_raw) +assert len(stmt_prepare_stream.contents) == 6 +assert isinstance(stmt_prepare_stream.contents[0].payload, MySQLStmtPrepareOK) +assert stmt_prepare_stream.contents[0].payload.statement_id == 1 +assert stmt_prepare_stream.contents[0].payload.num_columns == 1 +assert stmt_prepare_stream.contents[0].payload.num_params == 2 +assert isinstance(stmt_prepare_stream.contents[1].payload, MySQLColumnDefinition41) +assert isinstance(stmt_prepare_stream.contents[2].payload, MySQLColumnDefinition41) +assert isinstance(stmt_prepare_stream.contents[3].payload, MySQLEOFPacket) +assert isinstance(stmt_prepare_stream.contents[4].payload, MySQLColumnDefinition41) +assert stmt_prepare_stream.contents[4].payload.column_name == b"id" +assert isinstance(stmt_prepare_stream.contents[5].payload, MySQLEOFPacket) + += generic response packets official examples + +ok_raw = bytes.fromhex("0700000200000002000000") +ok_stream = MySQLServer(ok_raw) +ok_packet = ok_stream.contents[0].payload +assert isinstance(ok_packet, MySQLOKPacket) +assert ok_packet.affected_rows == 0 +assert ok_packet.last_insert_id == 0 +assert ok_packet.status_flags == 0x0002 +assert ok_packet.warnings == 0 + +err_raw = bytes.fromhex( + "17000001" + "ff4804" + "234859303030" + "4e6f207461626c65732075736564" +) +err_stream = MySQLServer(err_raw) +err_packet = err_stream.contents[0].payload +assert isinstance(err_packet, MySQLErrPacket) +assert err_packet.error_code == 1096 +assert err_packet.sql_state == b"HY000" +assert err_packet.error_message == b"No tables used" + +eof_raw = bytes.fromhex("05000005fe00000200") +eof_stream = MySQLServer(eof_raw) +eof_packet = eof_stream.contents[0].payload +assert isinstance(eof_packet, MySQLEOFPacket) +assert eof_packet.warnings == 0 +assert eof_packet.status_flags == 0x0002 + +short_eof_raw = bytes.fromhex("01000003fe") +short_eof_stream = MySQLServer(short_eof_raw) +short_eof_packet = short_eof_stream.contents[0].payload +assert isinstance(short_eof_packet, MySQLEOFPacket) +assert short_eof_packet.header == 0xFE + += text resultset response + +column_count_pkt = MySQLServerPacket(sequence_id=1) / MySQLResultSetColumnCount( + column_count=1 +) +column_def_pkt = MySQLServerPacket(sequence_id=2) / MySQLColumnDefinition41( + catalog=b"def", + schema=b"", + table=b"", + org_table=b"", + column_name=b"version", + org_column_name=b"version", + fixed_length_fields_len=0x0C, + character_set=0x21, + column_length=28, + column_type=0xFD, + flags=0, + decimals=0, + filler=0, +) +row_pkt = MySQLServerPacket(sequence_id=4) / MySQLTextResultSetRow( + values=[b"8.0.36"] +) + +resultset_raw = ( + bytes(column_count_pkt) + + bytes(column_def_pkt) + + eof_raw[:3] + b"\x03" + eof_raw[4:] + + bytes(row_pkt) + + eof_raw +) +resultset_stream = MySQLServer(resultset_raw) +assert len(resultset_stream.contents) == 5 +assert isinstance(resultset_stream.contents[0].payload, MySQLResultSetColumnCount) +assert resultset_stream.contents[0].payload.column_count == 1 +assert isinstance(resultset_stream.contents[1].payload, MySQLColumnDefinition41) +assert resultset_stream.contents[1].payload.column_name == b"version" +assert resultset_stream.contents[1].payload.column_type == 0xFD +assert isinstance(resultset_stream.contents[2].payload, MySQLEOFPacket) +assert isinstance(resultset_stream.contents[3].payload, MySQLTextResultSetRow) +assert resultset_stream.contents[3].payload.values == [b"8.0.36"] +assert isinstance(resultset_stream.contents[4].payload, MySQLEOFPacket) + += field list style response + +field_list_column_1 = MySQLServerPacket(sequence_id=1) / MySQLColumnDefinition41( + catalog=b"def", + schema=b"test", + table=b"test_table", + org_table=b"test_table", + column_name=b"id", + org_column_name=b"id", + fixed_length_fields_len=0x0C, + character_set=0x3F, + column_length=11, + column_type=0x03, + flags=0, + decimals=0, + filler=0, +) +field_list_column_2 = MySQLServerPacket(sequence_id=2) / MySQLColumnDefinition41( + catalog=b"def", + schema=b"test", + table=b"test_table", + org_table=b"test_table", + column_name=b"name", + org_column_name=b"name", + fixed_length_fields_len=0x0C, + character_set=0x08, + column_length=255, + column_type=0xFC, + flags=0x0010, + decimals=0, + filler=0, +) +field_list_raw = ( + bytes(field_list_column_1) + + bytes(field_list_column_2) + + bytes(MySQLServerPacket(sequence_id=3) / MySQLEOFPacket(header=0xFE)) +) +field_list_stream = MySQLServer(field_list_raw) +assert len(field_list_stream.contents) == 3 +assert isinstance(field_list_stream.contents[0].payload, MySQLColumnDefinition41) +assert field_list_stream.contents[0].payload.column_name == b"id" +assert isinstance(field_list_stream.contents[1].payload, MySQLColumnDefinition41) +assert field_list_stream.contents[1].payload.column_name == b"name" +assert isinstance(field_list_stream.contents[2].payload, MySQLEOFPacket) + += text resultset with nulls and multiple columns + +multi_column_count_pkt = MySQLServerPacket(sequence_id=1) / MySQLResultSetColumnCount( + column_count=2 +) +multi_column_def_1_pkt = MySQLServerPacket(sequence_id=2) / MySQLColumnDefinition41( + catalog=b"def", + schema=b"", + table=b"users", + org_table=b"users", + column_name=b"name", + org_column_name=b"name", + fixed_length_fields_len=0x0C, + character_set=0x21, + column_length=64, + column_type=0xFD, + flags=0, + decimals=0, + filler=0, +) +multi_column_def_2_pkt = MySQLServerPacket(sequence_id=3) / MySQLColumnDefinition41( + catalog=b"def", + schema=b"", + table=b"users", + org_table=b"users", + column_name=b"nickname", + org_column_name=b"nickname", + fixed_length_fields_len=0x0C, + character_set=0x21, + column_length=64, + column_type=0xFD, + flags=0, + decimals=0, + filler=0, +) +multi_row_pkt = MySQLServerPacket(sequence_id=5) / MySQLTextResultSetRow( + values=[b"alice", None] +) + +multi_resultset_raw = ( + bytes(multi_column_count_pkt) + + bytes(multi_column_def_1_pkt) + + bytes(multi_column_def_2_pkt) + + bytes(MySQLServerPacket(sequence_id=4) / MySQLEOFPacket(warnings=0, status_flags=0x0002)) + + bytes(multi_row_pkt) + + bytes(MySQLServerPacket(sequence_id=6) / MySQLEOFPacket(warnings=0, status_flags=0x0002)) +) +multi_resultset_stream = MySQLServer(multi_resultset_raw) +assert len(multi_resultset_stream.contents) == 6 +assert isinstance(multi_resultset_stream.contents[0].payload, MySQLResultSetColumnCount) +assert multi_resultset_stream.contents[0].payload.column_count == 2 +assert isinstance(multi_resultset_stream.contents[1].payload, MySQLColumnDefinition41) +assert multi_resultset_stream.contents[1].payload.column_name == b"name" +assert isinstance(multi_resultset_stream.contents[2].payload, MySQLColumnDefinition41) +assert multi_resultset_stream.contents[2].payload.column_name == b"nickname" +assert isinstance(multi_resultset_stream.contents[4].payload, MySQLTextResultSetRow) +assert multi_resultset_stream.contents[4].payload.values == [b"alice", None] +assert isinstance(multi_resultset_stream.contents[5].payload, MySQLEOFPacket) + += text resultset with deprecate eof ok terminators + +ok_terminator_pkt = MySQLServerPacket(sequence_id=3) / MySQLOKPacket( + affected_rows=0, + last_insert_id=0, + status_flags=CLIENT_DEPRECATE_EOF >> 16, + warnings=0, +) +ok_final_pkt = MySQLServerPacket(sequence_id=5) / MySQLOKPacket( + affected_rows=0, + last_insert_id=0, + status_flags=0x0002, + warnings=0, +) + +ok_terminated_resultset_raw = ( + bytes(column_count_pkt) + + bytes(column_def_pkt) + + bytes(ok_terminator_pkt) + + bytes(row_pkt) + + bytes(ok_final_pkt) +) +ok_terminated_resultset_stream = MySQLServer(ok_terminated_resultset_raw) +assert len(ok_terminated_resultset_stream.contents) == 5 +assert isinstance(ok_terminated_resultset_stream.contents[0].payload, MySQLResultSetColumnCount) +assert isinstance(ok_terminated_resultset_stream.contents[1].payload, MySQLColumnDefinition41) +assert isinstance(ok_terminated_resultset_stream.contents[2].payload, MySQLOKPacket) +assert isinstance(ok_terminated_resultset_stream.contents[3].payload, MySQLTextResultSetRow) +assert ok_terminated_resultset_stream.contents[3].payload.values == [b"8.0.36"] +assert isinstance(ok_terminated_resultset_stream.contents[4].payload, MySQLOKPacket) + += simple query build and parse + +query_pkt = MySQLClientPacket(sequence_id=0) / MySQLComQuery( + query=b"SELECT VERSION()" +) +query_raw = bytes(query_pkt) +assert query_raw == bytes.fromhex( + "11000000" + "0353454c4543542056455253494f4e2829" +) + +query_stream = MySQLClient(query_raw) +query_packet = query_stream.contents[0].payload +assert isinstance(query_packet, MySQLComQuery) +assert query_packet.cmd == 0x03 +assert query_packet.query == b"SELECT VERSION()" + += multi packet stream + +combined_server_stream = MySQLServer(handshake_raw + ok_raw + eof_raw) +assert len(combined_server_stream.contents) == 3 +assert isinstance(combined_server_stream.contents[0].payload, MySQLHandshakeV10) +assert isinstance(combined_server_stream.contents[1].payload, MySQLOKPacket) +assert isinstance(combined_server_stream.contents[2].payload, MySQLEOFPacket) From 34fdf6f1a9089fe43483294f5af643c1de02d975 Mon Sep 17 00:00:00 2001 From: pablogonzalezpe Date: Wed, 15 Apr 2026 00:04:53 +0200 Subject: [PATCH 2/2] test/contrib: expand mysql coverage --- test/contrib/mysql.uts | 171 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) diff --git a/test/contrib/mysql.uts b/test/contrib/mysql.uts index f87d94c6cd3..505e10db5ed 100644 --- a/test/contrib/mysql.uts +++ b/test/contrib/mysql.uts @@ -9,6 +9,7 @@ = mysql initialization from scapy.contrib.mysql import * +import scapy.contrib.mysql as mysql = handshakev10 build and parse @@ -451,3 +452,173 @@ assert len(combined_server_stream.contents) == 3 assert isinstance(combined_server_stream.contents[0].payload, MySQLHandshakeV10) assert isinstance(combined_server_stream.contents[1].payload, MySQLOKPacket) assert isinstance(combined_server_stream.contents[2].payload, MySQLEOFPacket) + += field and helper edge cases + +assert mysql._read_lenenc_int(b"\xfc\x34\x12") == (0x1234, 3) +assert mysql._read_lenenc_int(b"\xfd\x56\x34\x12") == (0x123456, 4) +assert mysql._read_lenenc_int(b"\xfe\x08\x07\x06\x05\x04\x03\x02\x01") == ( + 0x0102030405060708, + 9, +) +assert mysql._read_lenenc_int(b"\xff") == (0, 1) +assert mysql._build_lenenc_int(0x1234) == b"\xfc\x34\x12" +assert mysql._build_lenenc_int(0x123456) == b"\xfd\x56\x34\x12" +assert mysql._build_lenenc_int(0x0102030405060708) == ( + b"\xfe\x08\x07\x06\x05\x04\x03\x02\x01" +) +assert mysql._flag_repr(0x0002, mysql.MYSQL_STATUS_FLAGS) == "2 (AUTOCOMMIT)" +assert mysql._flag_repr(0, mysql.MYSQL_STATUS_FLAGS) == "0" + +lenenc_none = mysql.MySQLLenEncIntField("x").addfield(None, b"", None) +assert lenenc_none == b"\x00" +assert mysql.MySQLLenEncIntField("x").i2repr(None, 7) == "7" +assert mysql.MySQLCapabilityFlagsField("flags", 0).i2repr( + None, + CLIENT_PROTOCOL_41 | CLIENT_SSL, +) == "2560 (PROTOCOL_41|SSL)" +assert mysql.MySQLStatusFlagsField("status", 0).i2repr(None, 0x0002) == "2 (AUTOCOMMIT)" +assert mysql.MySQLColumnFlagsField("flags", 0).i2repr(None, 0x0080) == "128 (BINARY)" + +lenenc_str_field = mysql.MySQLLenEncStrField("s") +assert lenenc_str_field.addfield(None, b"", None) == b"\x00" +assert lenenc_str_field.addfield(None, b"", "abc") == b"\x03abc" +assert lenenc_str_field.i2repr(None, b"abc") == "b'abc'" + +text_values_field = mysql.MySQLTextRowValuesField("values") +assert text_values_field.addfield(None, b"", None) == b"" +assert text_values_field.addfield(None, b"", "alice") == b"\x05alice" +assert text_values_field.addfield(None, b"", ["alice", None]) == b"\x05alice\xfb" +assert text_values_field.i2repr(None, [b"alice"]) == "[b'alice']" + += auth response encoding variants + +lenenc_auth_pkt = MySQLClientPacket(sequence_id=1) / MySQLHandshakeResponse41( + client_flags=CLIENT_PROTOCOL_41 | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA, + max_packet_size=0, + character_set=0x21, + username=b"user", + auth_response=b"abc", +) +lenenc_auth_raw = bytes(lenenc_auth_pkt) +lenenc_auth_stream = MySQLClient(lenenc_auth_raw) +lenenc_auth = lenenc_auth_stream.contents[0].payload +assert isinstance(lenenc_auth, MySQLHandshakeResponse41) +assert lenenc_auth.auth_response == b"abc" + +secure_auth_pkt = MySQLClientPacket(sequence_id=1) / MySQLHandshakeResponse41( + client_flags=CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION, + max_packet_size=0, + character_set=0x21, + username=b"user", + auth_response="abc", +) +secure_auth_raw = bytes(secure_auth_pkt) +secure_auth_stream = MySQLClient(secure_auth_raw) +secure_auth = secure_auth_stream.contents[0].payload +assert isinstance(secure_auth, MySQLHandshakeResponse41) +assert secure_auth.auth_response == b"abc" + +old_auth_pkt = MySQLClientPacket(sequence_id=1) / MySQLHandshakeResponse41( + client_flags=CLIENT_PROTOCOL_41, + max_packet_size=0, + character_set=0x21, + username=b"user", + auth_response="abc", +) +old_auth_raw = bytes(old_auth_pkt) +old_auth_stream = MySQLClient(old_auth_raw) +old_auth = old_auth_stream.contents[0].payload +assert isinstance(old_auth, MySQLHandshakeResponse41) +assert old_auth.auth_response == b"abc" + +empty_secure_auth_pkt = MySQLClientPacket(sequence_id=1) / MySQLHandshakeResponse41( + client_flags=CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION, + max_packet_size=0, + character_set=0x21, + username=b"user", + auth_response=b"", +) +empty_secure_auth_raw = bytes(empty_secure_auth_pkt) +empty_secure_auth_stream = MySQLClient(empty_secure_auth_raw) +assert empty_secure_auth_stream.contents[0].payload.auth_response == b"" + +assert mysql.MySQLAuthResponseField("auth").i2repr(None, b"abc") == "b'abc'" +assert mysql.MySQLAuthResponseField("auth").getfield( + MySQLHandshakeResponse41(client_flags=CLIENT_PROTOCOL_41), + b"abc", +) == (b"", b"abc") +assert mysql.MySQLAuthResponseField("auth").getfield( + MySQLHandshakeResponse41( + client_flags=CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION + ), + b"", +) == (b"", b"") + += dispatcher fallbacks and error transitions + +ping_raw = bytes.fromhex("010000000e") +ping_stream = MySQLClient(ping_raw) +assert isinstance(ping_stream.contents[0].payload, MySQLCommand) +assert ping_stream.contents[0].payload.cmd == 0x0E + +unknown_client_raw = bytes.fromhex("010000007f") +unknown_client_stream = MySQLClient(unknown_client_raw) +assert isinstance(unknown_client_stream.contents[0].payload, Raw) + +unknown_server_raw = bytes.fromhex("02000001fe00") +unknown_server_stream = MySQLServer(unknown_server_raw) +assert isinstance(unknown_server_stream.contents[0].payload, Raw) + +resultset_err_raw = ( + bytes(column_count_pkt) + + bytes(column_def_pkt) + + bytes(MySQLServerPacket(sequence_id=3) / MySQLEOFPacket(warnings=0, status_flags=0x0002)) + + bytes(MySQLServerPacket(sequence_id=4) / MySQLErrPacket(error_code=1064, sql_state=b"42000", error_message=b"syntax error")) +) +resultset_err_stream = MySQLServer(resultset_err_raw) +assert len(resultset_err_stream.contents) == 4 +assert isinstance(resultset_err_stream.contents[3].payload, MySQLErrPacket) + +field_list_err_raw = ( + bytes(field_list_column_1) + + bytes(MySQLServerPacket(sequence_id=2) / MySQLErrPacket(error_code=1091, sql_state=b"42000", error_message=b"bad field list")) +) +field_list_err_stream = MySQLServer(field_list_err_raw) +assert len(field_list_err_stream.contents) == 2 +assert isinstance(field_list_err_stream.contents[0].payload, MySQLColumnDefinition41) +assert isinstance(field_list_err_stream.contents[1].payload, MySQLErrPacket) + +stmt_prepare_columns_only_raw = ( + bytes(MySQLServerPacket(sequence_id=1) / MySQLStmtPrepareOK(statement_id=2, num_columns=1, num_params=0, warning_count=0)) + + bytes(column_def_prepare_pkt) + + bytes(MySQLServerPacket(sequence_id=3) / MySQLEOFPacket(warnings=0, status_flags=0x0002)) +) +stmt_prepare_columns_only_stream = MySQLServer(stmt_prepare_columns_only_raw) +assert len(stmt_prepare_columns_only_stream.contents) == 3 +assert isinstance(stmt_prepare_columns_only_stream.contents[0].payload, MySQLStmtPrepareOK) +assert isinstance(stmt_prepare_columns_only_stream.contents[1].payload, MySQLColumnDefinition41) +assert isinstance(stmt_prepare_columns_only_stream.contents[2].payload, MySQLEOFPacket) + +stmt_prepare_err_raw = ( + bytes(MySQLServerPacket(sequence_id=1) / MySQLStmtPrepareOK(statement_id=3, num_columns=0, num_params=1, warning_count=0)) + + bytes(param_def_1_pkt) + + bytes(MySQLServerPacket(sequence_id=3) / MySQLEOFPacket(warnings=0, status_flags=0x0002)) + + bytes(MySQLServerPacket(sequence_id=4) / MySQLErrPacket(error_code=1243, sql_state=b"HY000", error_message=b"unknown statement")) +) +stmt_prepare_err_stream = MySQLServer(stmt_prepare_err_raw) +assert len(stmt_prepare_err_stream.contents) == 4 +assert isinstance(stmt_prepare_err_stream.contents[0].payload, MySQLStmtPrepareOK) +assert isinstance(stmt_prepare_err_stream.contents[1].payload, MySQLColumnDefinition41) +assert isinstance(stmt_prepare_err_stream.contents[2].payload, MySQLEOFPacket) +assert isinstance(stmt_prepare_err_stream.contents[3].payload, MySQLErrPacket) + += stream completeness and reassembly + +assert mysql._mysql_client_cls(None, [], None, b"\x01\x00\x00") is None +assert mysql._mysql_server_cls(None, [], None, b"\x01\x00\x00") is None +assert mysql._mysql_stream_complete(b"\x01\x00\x00\x00\x0e") +assert not mysql._mysql_stream_complete(b"\x01\x00\x00") +assert not mysql._mysql_stream_complete(b"\x05\x00\x00\x00abc") +assert MySQLClient.tcp_reassemble(b"\x01\x00\x00", {}, None) is None +assert isinstance(MySQLClient.tcp_reassemble(b"\x01\x00\x00\x00\x0e", {}, None), MySQLClient)