From 9aa090127217d3e7e3c71357f430b5cb464ee1e0 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 19 Mar 2026 15:38:22 -0500 Subject: [PATCH 01/19] Add WASM compatibility for componentize-py builds Defer top-level `import jwt` to function scope in auth.py, management/manager.py, and management/utils.py (jwt unavailable in WASM). Catch OSError in mysql/connection.py getpass handling (pwd module unavailable in WASM). Broaden except clause for IPython import in utils/events.py. Add singlestoredb/functions/ext/wasm/ package with udf_handler.py and numpy_stub.py so componentize-py components can `pip install` this branch and import directly from singlestoredb. Co-Authored-By: Claude Opus 4.6 --- singlestoredb/auth.py | 3 +- singlestoredb/functions/ext/wasm/__init__.py | 0 .../functions/ext/wasm/numpy_stub.py | 127 +++++ .../functions/ext/wasm/udf_handler.py | 455 ++++++++++++++++++ singlestoredb/management/manager.py | 2 +- singlestoredb/management/utils.py | 4 +- singlestoredb/mysql/connection.py | 3 +- singlestoredb/utils/events.py | 2 +- 8 files changed, 589 insertions(+), 7 deletions(-) create mode 100644 singlestoredb/functions/ext/wasm/__init__.py create mode 100644 singlestoredb/functions/ext/wasm/numpy_stub.py create mode 100644 singlestoredb/functions/ext/wasm/udf_handler.py diff --git a/singlestoredb/auth.py b/singlestoredb/auth.py index 2e10da285..fe94e3414 100644 --- a/singlestoredb/auth.py +++ b/singlestoredb/auth.py @@ -5,8 +5,6 @@ from typing import Optional from typing import Union -import jwt - # Credential types PASSWORD = 'password' @@ -42,6 +40,7 @@ def __init__( @classmethod def from_token(cls, token: bytes, verify_signature: bool = False) -> 'JSONWebToken': """Validate the contents of the JWT.""" + import jwt info = jwt.decode(token, options={'verify_signature': verify_signature}) if not info.get('sub', None) and not info.get('username', None): diff --git a/singlestoredb/functions/ext/wasm/__init__.py 
b/singlestoredb/functions/ext/wasm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/singlestoredb/functions/ext/wasm/numpy_stub.py b/singlestoredb/functions/ext/wasm/numpy_stub.py new file mode 100644 index 000000000..e343d43f9 --- /dev/null +++ b/singlestoredb/functions/ext/wasm/numpy_stub.py @@ -0,0 +1,127 @@ +""" +Minimal numpy stub for WASM environment. + +This provides just enough of numpy's interface for singlestoredb's +get_signature() function to work without actual numpy. +""" + + +class _DtypeMeta(type): + """Metaclass for dtype to support typing.get_origin checks.""" + pass + + +class dtype(metaclass=_DtypeMeta): + """Stub dtype class.""" + def __init__(self, spec=None): + self.spec = spec + + def __repr__(self): + return f'dtype({self.spec!r})' + + +class ndarray: + """Stub ndarray class.""" + pass + + +# Stub type classes - these just need to exist for isinstance/issubclass checks +class bool_: + pass + + +class integer: + pass + + +class int_(integer): + pass + + +class int8(integer): + pass + + +class int16(integer): + pass + + +class int32(integer): + pass + + +class int64(integer): + pass + + +class uint(integer): + pass + + +class unsignedinteger(integer): + pass + + +class uint8(unsignedinteger): + pass + + +class uint16(unsignedinteger): + pass + + +class uint32(unsignedinteger): + pass + + +class uint64(unsignedinteger): + pass + + +class longlong(integer): + pass + + +class ulonglong(unsignedinteger): + pass + + +class str_: + pass + + +class bytes_: + pass + + +class float16: + pass + + +class float32: + pass + + +class float64: + pass + + +class double(float64): + pass + + +class float_(float64): + pass + + +class single(float32): + pass + + +class unicode_(str_): + pass + + +# Common aliases +float = float64 +int = int64 diff --git a/singlestoredb/functions/ext/wasm/udf_handler.py b/singlestoredb/functions/ext/wasm/udf_handler.py new file mode 100644 index 000000000..d189bae15 --- /dev/null +++ 
b/singlestoredb/functions/ext/wasm/udf_handler.py @@ -0,0 +1,455 @@ +""" +Python UDF handler implementing the WIT interface for WASM component. + +This module provides a Python runtime for UDF functions. When compiled +with componentize-py, it becomes a WASM component that can be loaded by +the Rust UDF server. + +Functions are discovered automatically by scanning sys.modules for +@udf-decorated functions. No _exports.py is needed — just import +FunctionHandler from this module in your UDF file and decorate +functions with @udf. +""" +import difflib # noqa: F401 +import inspect +import json +import logging +import os +import sys +import traceback +import types +from typing import Any +from typing import Callable +from typing import Dict +from typing import List + + +# Install numpy stub before importing singlestoredb (which tries to import numpy) +if 'numpy' not in sys.modules: + try: + import numpy # noqa: F401 + except ImportError: + from . import numpy_stub + sys.modules['numpy'] = numpy_stub + +from singlestoredb.functions.signature import get_signature +from singlestoredb.functions.ext.rowdat_1 import load as _load_rowdat_1 +from singlestoredb.functions.ext.rowdat_1 import dump as _dump_rowdat_1 +from singlestoredb.mysql.constants import FIELD_TYPE as ft + +try: + from _singlestoredb_accel import call_function_accel as _call_function_accel + _has_call_accel = True +except Exception: + _has_call_accel = False + + +class _TracingFormatter(logging.Formatter): + """Match Rust tracing-subscriber's colored output format.""" + + _RESET = '\033[0m' + _DIM = '\033[2m' + _BOLD = '\033[1m' + _LEVEL_COLORS = { + 'DEBUG': '\033[34m', # blue + 'INFO': '\033[32m', # green + 'WARNING': '\033[33m', # yellow + 'ERROR': '\033[31m', # red + 'CRITICAL': '\033[31m', # red + } + + def formatTime(self, record, datefmt=None): + from datetime import datetime, timezone + dt = datetime.fromtimestamp(record.created, tz=timezone.utc) + return dt.strftime('%Y-%m-%dT%H:%M:%S.') + 
f'{dt.microsecond:06d}Z' + + def format(self, record): + ts = self.formatTime(record) + color = self._LEVEL_COLORS.get(record.levelname, '') + level = f'{color}{self._BOLD}{record.levelname:>5}{self._RESET}' + name = f'{self._DIM}{record.name}{self._RESET}' + msg = record.getMessage() + return f'{self._DIM}{ts}{self._RESET} {level} {name}: {msg}' + + +_handler = logging.StreamHandler() +_handler.setFormatter(_TracingFormatter()) +logging.basicConfig(level=logging.INFO, handlers=[_handler]) +logger = logging.getLogger('udf_handler') + +# Map dtype strings to rowdat_1 type codes for wire serialization. +# rowdat_1 always uses 8-byte encoding for integers and doubles for floats, +# so all int types collapse to LONGLONG and all float types to DOUBLE. +# Uses negative values for unsigned ints / binary data. +rowdat_1_type_map = { + 'bool': ft.LONGLONG, + 'int8': ft.LONGLONG, + 'int16': ft.LONGLONG, + 'int32': ft.LONGLONG, + 'int64': ft.LONGLONG, + 'uint8': -ft.LONGLONG, + 'uint16': -ft.LONGLONG, + 'uint32': -ft.LONGLONG, + 'uint64': -ft.LONGLONG, + 'float32': ft.DOUBLE, + 'float64': ft.DOUBLE, + 'str': ft.STRING, + 'bytes': -ft.STRING, +} + + +class FunctionRegistry: + """Registry of discovered UDF functions.""" + + def __init__(self): + self.functions: Dict[str, Dict[str, Any]] = {} + + def initialize(self) -> None: + """Initialize and discover UDF functions from loaded modules. + + Scans sys.modules for any module containing @udf-decorated + functions. No _exports.py is needed — modules just need to be + imported before initialize() is called (componentize-py captures + them at build time). + """ + self._discover_udf_functions() + + @staticmethod + def _is_stdlib_or_infra(mod_name: str, mod_file: str) -> bool: + """Check if a module is stdlib or infrastructure (not user UDF code). + + Uses the module's __file__ path to detect stdlib modules + (under sys.prefix but not in site-packages) rather than + maintaining a hardcoded list of names. 
+ """ + # Infrastructure modules that are part of this project + _infra = frozenset({ + 'udf_handler', 'numpy_stub', + }) + if mod_name in _infra: + return True + + # Resolve symlinks for reliable prefix comparison + real_file = os.path.realpath(mod_file) + real_prefix = os.path.realpath(sys.prefix) + + # Modules under sys.prefix but NOT in site-packages are stdlib + if real_file.startswith(real_prefix + os.sep): + if 'site-packages' not in real_file: + return True + + return False + + def _discover_udf_functions(self) -> None: + """Discover @udf functions by scanning sys.modules. + + Uses a two-pass approach: first, identify candidate modules + that import FunctionHandler (the convention for UDF modules). + Then extract @udf-decorated functions from those modules. + Modules without a __file__ (built-in/frozen) and stdlib/ + infrastructure modules are skipped automatically. + """ + found_modules = [] + for mod_name, mod in list(sys.modules.items()): + if mod is None: + continue + if not isinstance(mod, types.ModuleType): + continue + mod_file = getattr(mod, '__file__', None) + if mod_file is None: + continue + + # Short-circuit: only scan modules that import + # FunctionHandler (the convention for UDF modules) + if not any( + obj is FunctionHandler + for obj in vars(mod).values() + ): + continue + + # Skip stdlib and infrastructure modules + if self._is_stdlib_or_infra(mod_name, mod_file): + continue + + self._extract_functions(mod) + if any( + hasattr(obj, '_singlestoredb_attrs') + for _, obj in inspect.getmembers(mod) + if inspect.isfunction(obj) + ): + found_modules.append(mod_name) + + if found_modules: + logger.info( + f'Discovered UDF functions from modules: ' + f'{", ".join(sorted(found_modules))}', + ) + else: + logger.warning( + 'No modules with @udf functions found in sys.modules.', + ) + + def _extract_functions(self, module: Any) -> None: + """Extract @udf-decorated functions from a module. 
+ + Unlike module scanning, this does not filter by __module__ — + _exports.py may re-export functions defined in other modules. + """ + for name, obj in inspect.getmembers(module): + if name.startswith('_'): + continue + + if not callable(obj): + continue + + if not inspect.isfunction(obj): + continue + + # Only register functions decorated with @udf + if not hasattr(obj, '_singlestoredb_attrs'): + continue + + try: + sig = get_signature(obj) + if sig and sig.get('args') is not None and sig.get('returns'): + self._register_function(obj, name, sig) + except (TypeError, ValueError): + # Skip functions that can't be introspected + pass + + def _build_json_descriptions( + self, + func_names: List[str], + ) -> List[Dict[str, Any]]: + """Build JSON-serializable descriptions for the given function names. + + Extracts metadata from the stored signature dict for each function. + """ + descriptions = [] + for func_name in func_names: + func_info = self.functions[func_name] + sig = func_info['signature'] + args = [] + for arg in sig['args']: + args.append({ + 'name': arg['name'], + 'dtype': arg['dtype'], + 'sql': arg['sql'], + }) + returns = [] + for ret in sig['returns']: + returns.append({ + 'name': ret.get('name') or None, + 'dtype': ret['dtype'], + 'sql': ret['sql'], + }) + descriptions.append({ + 'name': func_name, + 'args': args, + 'returns': returns, + 'args_data_format': sig.get('args_data_format') or 'scalar', + 'returns_data_format': ( + sig.get('returns_data_format') or 'scalar' + ), + 'function_type': sig.get('function_type') or 'udf', + 'doc': sig.get('doc'), + }) + return descriptions + + def create_function( + self, + signature_json: str, + code: str, + replace: bool, + ) -> List[str]: + """Register a function from its signature and Python source code. 
+ + Args: + signature_json: JSON object matching the describe-functions + element schema (must contain a 'name' field) + code: Python source code containing the @udf-decorated function + replace: If False, raise an error if the function already exists + + Returns: + List of newly registered function names + + Raises: + SyntaxError: If the code has syntax errors + ValueError: If no @udf-decorated functions are found or + function already exists and replace is False + """ + sig = json.loads(signature_json) + func_name = sig.get('name') + if not func_name: + raise ValueError( + 'signature JSON must contain a "name" field', + ) + + # Check for name collision when replace is False + if not replace and func_name in self.functions: + raise ValueError( + f'Function "{func_name}" already exists ' + f'(use replace=true to overwrite)', + ) + + # When replacing, remove the old entry so the new registration + # is detected as "new" by the before/after name comparison. + if replace and func_name in self.functions: + del self.functions[func_name] + + # Use __main__ as the module name for dynamically submitted functions + name = '__main__' + + # Validate syntax + compiled = compile(code, f'<{name}>', 'exec') + + # Reuse existing module to avoid corrupting the componentize-py + # runtime state (replacing sys.modules['__main__'] traps WASM). 
+ if name in sys.modules: + module = sys.modules[name] + else: + module = types.ModuleType(name) + module.__file__ = f'<{name}>' + sys.modules[name] = module + exec(compiled, module.__dict__) # noqa: S102 + + # Extract functions from the module + before_names = set(self.functions.keys()) + self._extract_functions(module) + new_names = [k for k in self.functions if k not in before_names] + + if not new_names: + raise ValueError( + 'No @udf-decorated functions found in submitted code', + ) + + logger.info( + f'create_function({func_name}): registered ' + f'{len(new_names)} function(s): {", ".join(new_names)}', + ) + return new_names + + def _register_function( + self, + func: Callable, + func_name: str, + sig: Dict[str, Any], + ) -> None: + """Register a function under its bare name. + + All functions are registered as top-level names (no module prefix). + If a function with the same name already exists, the last + registration wins. + """ + # Use alias name from signature if available, otherwise use function name + full_name = sig.get('name') or func_name + + # Convert args to (name, type_code) tuples + arg_types = [] + for arg in sig['args']: + dtype = arg['dtype'].replace('?', '') + if dtype not in rowdat_1_type_map: + logger.warning(f"Skipping {full_name}: unsupported arg dtype '{dtype}'") + return + arg_types.append((arg['name'], rowdat_1_type_map[dtype])) + + # Convert returns to type_code list + return_types = [] + for ret in sig['returns']: + dtype = ret['dtype'].replace('?', '') + if dtype not in rowdat_1_type_map: + logger.warning(f'Skipping {full_name}: no type mapping for {dtype}') + return + return_types.append(rowdat_1_type_map[dtype]) + + self.functions[full_name] = { + 'func': func, + 'arg_types': arg_types, + 'return_types': return_types, + 'signature': sig, + } + + +# Global registry instance +_registry = FunctionRegistry() + + +class FunctionHandler: + """Implementation of the singlestore:udf/function-handler interface.""" + + def initialize(self) 
-> None: + """Initialize and discover UDF functions from loaded modules.""" + if _has_call_accel: + logger.info('Using accelerated C call_function_accel loop') + else: + logger.info('Using pure Python call_function loop') + _registry.initialize() + + def call_function(self, name: str, input_data: bytes) -> bytes: + """Call a function by its registered name.""" + if name not in _registry.functions: + raise ValueError(f'unknown function: {name}') + + func_info = _registry.functions[name] + func = func_info['func'] + arg_types = func_info['arg_types'] + return_types = func_info['return_types'] + + try: + if _has_call_accel: + return _call_function_accel( + colspec=arg_types, + returns=return_types, + data=input_data, + func=func, + ) + + # Fallback to pure Python + row_ids, rows = _load_rowdat_1(arg_types, input_data) + results = [] + for row in rows: + result = func(*row) + if not isinstance(result, tuple): + result = [result] + results.append(list(result)) + return bytes(_dump_rowdat_1(return_types, row_ids, results)) + + except Exception as e: + tb = traceback.format_exc() + raise RuntimeError(f'Error calling {name}: {e}\n{tb}') + + def describe_functions(self) -> str: + """Describe all functions as a JSON array. + + Returns a JSON string containing an array of function description + objects with: name, args, returns, args_data_format, + returns_data_format, function_type, doc. + + Raises RuntimeError on failure (mapped to result Err by + componentize-py). + """ + try: + func_names = list(_registry.functions.keys()) + descriptions = _registry._build_json_descriptions(func_names) + return json.dumps(descriptions) + except Exception as e: + tb = traceback.format_exc() + raise RuntimeError(f'{e}\n{tb}') + + def create_function( + self, + signature: str, + code: str, + replace: bool, + ) -> None: + """Register a function from its signature and Python source code. + + Returns None on success (mapped to result Ok(()) by componentize-py). 
+ Raises RuntimeError on failure (mapped to result Err by + componentize-py). + """ + try: + _registry.create_function(signature, code, replace) + except Exception as e: + tb = traceback.format_exc() + raise RuntimeError(f'{e}\n{tb}') diff --git a/singlestoredb/management/manager.py b/singlestoredb/management/manager.py index 144dbb3e3..575df0876 100644 --- a/singlestoredb/management/manager.py +++ b/singlestoredb/management/manager.py @@ -10,7 +10,6 @@ from typing import Union from urllib.parse import urljoin -import jwt import requests from .. import config @@ -33,6 +32,7 @@ def set_organization(kwargs: Dict[str, Any]) -> None: def is_jwt(token: str) -> bool: """Is the given token a JWT?""" + import jwt try: jwt.decode(token, options={'verify_signature': False}) return True diff --git a/singlestoredb/management/utils.py b/singlestoredb/management/utils.py index 5aa072e41..ea0e04d21 100644 --- a/singlestoredb/management/utils.py +++ b/singlestoredb/management/utils.py @@ -18,8 +18,6 @@ from typing import Union from urllib.parse import urlparse -import jwt - from .. 
import converters from ..config import get_option from ..utils import events @@ -151,6 +149,7 @@ def handle_connection_info(msg: Dict[str, Any]) -> None: def retrieve_current_authentication_info() -> List[Tuple[str, Any]]: """Retrieve JWT if not expired.""" + import jwt nonlocal authentication_info password = authentication_info.get('password') if password: @@ -198,6 +197,7 @@ def get_authentication_info(include_env: bool = True) -> Dict[str, Any]: def get_token() -> Optional[str]: """Return the token for the Management API.""" + import jwt # See if an API key is configured tok = get_option('management.token') if tok: diff --git a/singlestoredb/mysql/connection.py b/singlestoredb/mysql/connection.py index 094fad684..2d4fdd2a9 100644 --- a/singlestoredb/mysql/connection.py +++ b/singlestoredb/mysql/connection.py @@ -87,8 +87,9 @@ DEFAULT_USER = getpass.getuser() del getpass -except (ImportError, KeyError): +except (ImportError, KeyError, OSError): # KeyError occurs when there's no entry in OS database for a current user. + # OSError occurs in WASM environments where pwd module is unavailable. DEFAULT_USER = None DEBUG = get_option('debug.connection') diff --git a/singlestoredb/utils/events.py b/singlestoredb/utils/events.py index dab01f08f..2054b09d1 100644 --- a/singlestoredb/utils/events.py +++ b/singlestoredb/utils/events.py @@ -7,7 +7,7 @@ try: from IPython import get_ipython has_ipython = True -except ImportError: +except Exception: has_ipython = False From 73dcaad06bc6aa149ad4fec00f37a56f68482be3 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 19 Mar 2026 15:44:04 -0500 Subject: [PATCH 02/19] Add WIT interface definition for WASM UDF components Required by componentize-py to build function-handler components. 
Co-Authored-By: Claude Opus 4.6 --- wit/udf.wit | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 wit/udf.wit diff --git a/wit/udf.wit b/wit/udf.wit new file mode 100644 index 000000000..e06e35a41 --- /dev/null +++ b/wit/udf.wit @@ -0,0 +1,26 @@ +package singlestore:udf; + +interface function-handler { + /// Initialize the handler and discover pre-imported UDF modules. + /// Must be called before any other functions. + initialize: func() -> result<_, string>; + + /// Call a function by its registered name (e.g., "my_func") + /// Input/output are rowdat_1 binary format + call-function: func(name: string, input: list<u8>) -> result<list<u8>, string>; + + /// Describe all registered functions as a JSON array of objects. + /// Each object has: name, args [{name, dtype, sql}], returns [{name, dtype, sql}], + /// args_data_format, returns_data_format, function_type, doc + describe-functions: func() -> result<string, string>; + + /// Register a function from its signature and Python source code. + /// `signature` is a JSON object matching the describe-functions element schema. + /// `code` is the Python source containing the @udf-decorated function. + /// `replace` controls whether an existing function of the same name is overwritten.
+ create-function: func(signature: string, code: string, replace: bool) -> result<_, string>; +} + +world external-udf { + export function-handler; +} From 49521d1f145d7f8ad38bda7a3fabae9a15a7b334 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Fri, 20 Mar 2026 12:55:53 -0500 Subject: [PATCH 03/19] Add type annotations to WASM numpy stub and UDF handler Co-Authored-By: Claude Opus 4.6 --- singlestoredb/functions/ext/wasm/numpy_stub.py | 4 ++-- singlestoredb/functions/ext/wasm/udf_handler.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/singlestoredb/functions/ext/wasm/numpy_stub.py b/singlestoredb/functions/ext/wasm/numpy_stub.py index e343d43f9..386c4a90d 100644 --- a/singlestoredb/functions/ext/wasm/numpy_stub.py +++ b/singlestoredb/functions/ext/wasm/numpy_stub.py @@ -13,10 +13,10 @@ class _DtypeMeta(type): class dtype(metaclass=_DtypeMeta): """Stub dtype class.""" - def __init__(self, spec=None): + def __init__(self, spec: object = None) -> None: self.spec = spec - def __repr__(self): + def __repr__(self) -> str: return f'dtype({self.spec!r})' diff --git a/singlestoredb/functions/ext/wasm/udf_handler.py b/singlestoredb/functions/ext/wasm/udf_handler.py index d189bae15..fa4ab4128 100644 --- a/singlestoredb/functions/ext/wasm/udf_handler.py +++ b/singlestoredb/functions/ext/wasm/udf_handler.py @@ -58,12 +58,12 @@ class _TracingFormatter(logging.Formatter): 'CRITICAL': '\033[31m', # red } - def formatTime(self, record, datefmt=None): + def formatTime(self, record: logging.LogRecord, datefmt: str | None = None) -> str: from datetime import datetime, timezone dt = datetime.fromtimestamp(record.created, tz=timezone.utc) return dt.strftime('%Y-%m-%dT%H:%M:%S.') + f'{dt.microsecond:06d}Z' - def format(self, record): + def format(self, record: logging.LogRecord) -> str: ts = self.formatTime(record) color = self._LEVEL_COLORS.get(record.levelname, '') level = f'{color}{self._BOLD}{record.levelname:>5}{self._RESET}' @@ -101,7 +101,7 @@ def 
format(self, record): class FunctionRegistry: """Registry of discovered UDF functions.""" - def __init__(self): + def __init__(self) -> None: self.functions: Dict[str, Dict[str, Any]] = {} def initialize(self) -> None: @@ -331,7 +331,7 @@ def create_function( def _register_function( self, - func: Callable, + func: Callable[..., Any], func_name: str, sig: Dict[str, Any], ) -> None: From 5d1e9e63f70ed146b8e7e161a0e710d4d74e5a2c Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Mon, 23 Mar 2026 11:20:23 -0500 Subject: [PATCH 04/19] Add code generation for UDF handler function registration Build complete @udf-decorated Python functions from signature metadata and raw function body instead of requiring full source code. This adds dtype-to-Python type mapping and constructs properly annotated functions at registration time. Co-Authored-By: Claude Opus 4.6 --- .../functions/ext/wasm/udf_handler.py | 108 ++++++++++++++++-- 1 file changed, 101 insertions(+), 7 deletions(-) diff --git a/singlestoredb/functions/ext/wasm/udf_handler.py b/singlestoredb/functions/ext/wasm/udf_handler.py index fa4ab4128..eb4f13feb 100644 --- a/singlestoredb/functions/ext/wasm/udf_handler.py +++ b/singlestoredb/functions/ext/wasm/udf_handler.py @@ -97,6 +97,25 @@ def format(self, record: logging.LogRecord) -> str: 'bytes': -ft.STRING, } +# Map dtype strings to Python type annotation strings for code generation. +_dtype_to_python: Dict[str, str] = { + 'bool': 'bool', + 'int8': 'int', + 'int16': 'int', + 'int32': 'int', + 'int64': 'int', + 'int': 'int', + 'uint8': 'int', + 'uint16': 'int', + 'uint32': 'int', + 'uint64': 'int', + 'float32': 'float', + 'float64': 'float', + 'float': 'float', + 'str': 'str', + 'bytes': 'bytes', +} + class FunctionRegistry: """Registry of discovered UDF functions.""" @@ -256,27 +275,98 @@ def _build_json_descriptions( }) return descriptions + @staticmethod + def _python_type_annotation(dtype: str) -> str: + """Convert a dtype string to a Python type annotation. 
+ + Handles nullable types (trailing '?') by wrapping in Optional. + """ + nullable = dtype.endswith('?') + base = dtype.rstrip('?') + py_type = _dtype_to_python.get(base) + if py_type is None: + raise ValueError(f'Unsupported dtype: {dtype!r}') + if nullable: + return f'Optional[{py_type}]' + return py_type + + @staticmethod + def _build_python_code( + sig: Dict[str, Any], + body: str, + ) -> str: + """Build a complete @udf-decorated Python function from signature and body. + + Args: + sig: Parsed signature dict with 'name', 'args', 'returns'. + body: The function body (e.g. "return x * 3"). + + Returns: + Complete Python source with imports and a @udf-decorated function. + """ + func_name = sig['name'] + args = sig.get('args', []) + returns = sig.get('returns', []) + + # Build parameter list with type annotations + params = [] + for arg in args: + ann = FunctionRegistry._python_type_annotation(arg['dtype']) + params.append(f'{arg["name"]}: {ann}') + params_str = ', '.join(params) + + # Build return type annotation + if len(returns) == 0: + ret_ann = 'None' + elif len(returns) == 1: + ret_ann = FunctionRegistry._python_type_annotation( + returns[0]['dtype'], + ) + else: + parts = [ + FunctionRegistry._python_type_annotation(r['dtype']) + for r in returns + ] + ret_ann = f'Tuple[{", ".join(parts)}]' + + # Indent body lines + indented_body = '\n'.join( + f' {line}' for line in body.splitlines() + ) + + return ( + 'from singlestoredb.functions import udf\n' + 'from typing import Optional, Tuple\n' + '\n' + '@udf\n' + f'def {func_name}({params_str}) -> {ret_ann}:\n' + f'{indented_body}\n' + ) + def create_function( self, signature_json: str, code: str, replace: bool, ) -> List[str]: - """Register a function from its signature and Python source code. + """Register a function from its signature and function body. + + Constructs a complete @udf-decorated Python function from the + signature metadata and the raw function body, then compiles + and executes it. 
Args: signature_json: JSON object matching the describe-functions element schema (must contain a 'name' field) - code: Python source code containing the @udf-decorated function + code: Function body (e.g. "return x * 3"), not full source replace: If False, raise an error if the function already exists Returns: List of newly registered function names Raises: - SyntaxError: If the code has syntax errors - ValueError: If no @udf-decorated functions are found or - function already exists and replace is False + SyntaxError: If the generated code has syntax errors + ValueError: If the function already exists and replace is False """ sig = json.loads(signature_json) func_name = sig.get('name') @@ -297,11 +387,14 @@ def create_function( if replace and func_name in self.functions: del self.functions[func_name] + # Build a complete @udf-decorated function from signature + body + full_code = self._build_python_code(sig, code) + # Use __main__ as the module name for dynamically submitted functions name = '__main__' # Validate syntax - compiled = compile(code, f'<{name}>', 'exec') + compiled = compile(full_code, f'<{name}>', 'exec') # Reuse existing module to avoid corrupting the componentize-py # runtime state (replacing sys.modules['__main__'] traps WASM). @@ -320,7 +413,8 @@ def create_function( if not new_names: raise ValueError( - 'No @udf-decorated functions found in submitted code', + f'Function "{func_name}" was not registered. ' + f'Check that the signature dtypes are supported.', ) logger.info( From aed95df9b30527860b903d87fd5828981cc2ec7e Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Mon, 23 Mar 2026 12:54:21 -0500 Subject: [PATCH 05/19] Replace top-level optional imports with lazy import helpers Heavy optional dependencies (numpy, pandas, polars, pyarrow) were imported at module load time, causing failures in WASM environments where these packages may not be available. 
This adds a lazy import utility module and converts all eager try/except import patterns to use cached lazy accessors. Type maps in dtypes.py are also converted from module-level dicts to lru_cached factory functions. The pandas DataFrame isinstance check in connection.py is replaced with a duck-type hasattr check to avoid importing pandas at module scope. Co-Authored-By: Claude Opus 4.6 --- singlestoredb/connection.py | 8 +- singlestoredb/converters.py | 76 ++++++++------- singlestoredb/functions/dtypes.py | 7 +- singlestoredb/functions/ext/json.py | 15 ++- singlestoredb/functions/ext/rowdat_1.py | 21 ++-- singlestoredb/tests/test_connection.py | 34 +++---- singlestoredb/utils/_lazy_import.py | 42 ++++++++ singlestoredb/utils/dtypes.py | 58 ++++++------ singlestoredb/utils/results.py | 121 +++++++++++++----------- 9 files changed, 220 insertions(+), 162 deletions(-) create mode 100644 singlestoredb/utils/_lazy_import.py diff --git a/singlestoredb/connection.py b/singlestoredb/connection.py index 942b2feb9..01d8c2463 100644 --- a/singlestoredb/connection.py +++ b/singlestoredb/connection.py @@ -25,12 +25,6 @@ from urllib.parse import urlparse import sqlparams -try: - from pandas import DataFrame -except ImportError: - class DataFrame(object): # type: ignore - def itertuples(self, *args: Any, **kwargs: Any) -> None: - pass from . import auth from . 
import exceptions @@ -1175,7 +1169,7 @@ def _iquery( out = list(cur.fetchall()) if not out: return [] - if isinstance(out, DataFrame): + if hasattr(out, 'to_dict') and callable(getattr(out, 'to_dict')): out = out.to_dict(orient='records') elif isinstance(out[0], (tuple, list)): if cur.description: diff --git a/singlestoredb/converters.py b/singlestoredb/converters.py index ec9b73580..818c18ed3 100644 --- a/singlestoredb/converters.py +++ b/singlestoredb/converters.py @@ -26,11 +26,7 @@ except (AttributeError, ImportError): has_pygeos = False -try: - import numpy - has_numpy = True -except ImportError: - has_numpy = False +from .utils._lazy_import import get_numpy try: import bson @@ -563,8 +559,9 @@ def float32_vector_json_or_none(x: Optional[str]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.array(json_loads(x), dtype=numpy.float32) + np = get_numpy() + if np is not None: + return np.array(json_loads(x), dtype=np.float32) return map(float, json_loads(x)) @@ -591,8 +588,9 @@ def float32_vector_or_none(x: Optional[bytes]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.frombuffer(x, dtype=numpy.float32) + np = get_numpy() + if np is not None: + return np.frombuffer(x, dtype=np.float32) return struct.unpack(f'<{len(x)//4}f', x) @@ -619,8 +617,9 @@ def float16_vector_json_or_none(x: Optional[str]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.array(json_loads(x), dtype=numpy.float16) + np = get_numpy() + if np is not None: + return np.array(json_loads(x), dtype=np.float16) return map(float, json_loads(x)) @@ -647,8 +646,9 @@ def float16_vector_or_none(x: Optional[bytes]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.frombuffer(x, dtype=numpy.float16) + np = get_numpy() + if np is not None: + return np.frombuffer(x, dtype=np.float16) return struct.unpack(f'<{len(x)//2}e', x) @@ -675,8 +675,9 @@ def float64_vector_json_or_none(x: Optional[str]) -> 
Optional[Any]: if x is None: return None - if has_numpy: - return numpy.array(json_loads(x), dtype=numpy.float64) + np = get_numpy() + if np is not None: + return np.array(json_loads(x), dtype=np.float64) return map(float, json_loads(x)) @@ -703,8 +704,9 @@ def float64_vector_or_none(x: Optional[bytes]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.frombuffer(x, dtype=numpy.float64) + np = get_numpy() + if np is not None: + return np.frombuffer(x, dtype=np.float64) return struct.unpack(f'<{len(x)//8}d', x) @@ -731,8 +733,9 @@ def int8_vector_json_or_none(x: Optional[str]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.array(json_loads(x), dtype=numpy.int8) + np = get_numpy() + if np is not None: + return np.array(json_loads(x), dtype=np.int8) return map(int, json_loads(x)) @@ -759,8 +762,9 @@ def int8_vector_or_none(x: Optional[bytes]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.frombuffer(x, dtype=numpy.int8) + np = get_numpy() + if np is not None: + return np.frombuffer(x, dtype=np.int8) return struct.unpack(f'<{len(x)}b', x) @@ -787,8 +791,9 @@ def int16_vector_json_or_none(x: Optional[str]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.array(json_loads(x), dtype=numpy.int16) + np = get_numpy() + if np is not None: + return np.array(json_loads(x), dtype=np.int16) return map(int, json_loads(x)) @@ -815,8 +820,9 @@ def int16_vector_or_none(x: Optional[bytes]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.frombuffer(x, dtype=numpy.int16) + np = get_numpy() + if np is not None: + return np.frombuffer(x, dtype=np.int16) return struct.unpack(f'<{len(x)//2}h', x) @@ -843,8 +849,9 @@ def int32_vector_json_or_none(x: Optional[str]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.array(json_loads(x), dtype=numpy.int32) + np = get_numpy() + if np is not None: + return np.array(json_loads(x), 
dtype=np.int32) return map(int, json_loads(x)) @@ -871,8 +878,9 @@ def int32_vector_or_none(x: Optional[bytes]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.frombuffer(x, dtype=numpy.int32) + np = get_numpy() + if np is not None: + return np.frombuffer(x, dtype=np.int32) return struct.unpack(f'<{len(x)//4}l', x) @@ -899,8 +907,9 @@ def int64_vector_json_or_none(x: Optional[str]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.array(json_loads(x), dtype=numpy.int64) + np = get_numpy() + if np is not None: + return np.array(json_loads(x), dtype=np.int64) return map(int, json_loads(x)) @@ -928,8 +937,9 @@ def int64_vector_or_none(x: Optional[bytes]) -> Optional[Any]: return None # Bytes - if has_numpy: - return numpy.frombuffer(x, dtype=numpy.int64) + np = get_numpy() + if np is not None: + return np.frombuffer(x, dtype=np.int64) return struct.unpack(f'<{len(x)//8}l', x) diff --git a/singlestoredb/functions/dtypes.py b/singlestoredb/functions/dtypes.py index 0fe26a452..80aabd615 100644 --- a/singlestoredb/functions/dtypes.py +++ b/singlestoredb/functions/dtypes.py @@ -11,10 +11,9 @@ from ..converters import converters from ..mysql.converters import escape_item # type: ignore from ..utils.dtypes import DEFAULT_VALUES # noqa -from ..utils.dtypes import NUMPY_TYPE_MAP # noqa -from ..utils.dtypes import PANDAS_TYPE_MAP # noqa -from ..utils.dtypes import POLARS_TYPE_MAP # noqa -from ..utils.dtypes import PYARROW_TYPE_MAP # noqa +from ..utils.dtypes import get_numpy_type_map # noqa +from ..utils.dtypes import get_polars_type_map # noqa +from ..utils.dtypes import get_pyarrow_type_map # noqa DataType = Union[str, Callable[..., Any]] diff --git a/singlestoredb/functions/ext/json.py b/singlestoredb/functions/ext/json.py index 05710247d..619c3ad7e 100644 --- a/singlestoredb/functions/ext/json.py +++ b/singlestoredb/functions/ext/json.py @@ -7,10 +7,9 @@ from typing import TYPE_CHECKING from ..dtypes import DEFAULT_VALUES 
-from ..dtypes import NUMPY_TYPE_MAP -from ..dtypes import PANDAS_TYPE_MAP -from ..dtypes import POLARS_TYPE_MAP -from ..dtypes import PYARROW_TYPE_MAP +from ..dtypes import get_numpy_type_map +from ..dtypes import get_polars_type_map +from ..dtypes import get_pyarrow_type_map from ..dtypes import PYTHON_CONVERTERS if TYPE_CHECKING: @@ -140,7 +139,7 @@ def load_pandas( ( pd.Series( data, index=index, name=spec[0], - dtype=PANDAS_TYPE_MAP[spec[1]], + dtype=get_numpy_type_map()[spec[1]], ), pd.Series(mask, index=index, dtype=np.longlong), ) @@ -172,7 +171,7 @@ def load_polars( return pl.Series(None, row_ids, dtype=pl.Int64), \ [ ( - pl.Series(spec[0], data, dtype=POLARS_TYPE_MAP[spec[1]]), + pl.Series(spec[0], data, dtype=get_polars_type_map()[spec[1]]), pl.Series(None, mask, dtype=pl.Boolean), ) for (data, mask), spec in zip(cols, colspec) @@ -203,7 +202,7 @@ def load_numpy( return np.asarray(row_ids, dtype=np.longlong), \ [ ( - np.asarray(data, dtype=NUMPY_TYPE_MAP[spec[1]]), # type: ignore + np.asarray(data, dtype=get_numpy_type_map()[spec[1]]), # type: ignore np.asarray(mask, dtype=np.bool_), # type: ignore ) for (data, mask), spec in zip(cols, colspec) @@ -235,7 +234,7 @@ def load_arrow( [ ( pa.array( - data, type=PYARROW_TYPE_MAP[dtype], + data, type=get_pyarrow_type_map()[dtype], mask=pa.array(mask, type=pa.bool_()), ), pa.array(mask, type=pa.bool_()), diff --git a/singlestoredb/functions/ext/rowdat_1.py b/singlestoredb/functions/ext/rowdat_1.py index 94e966b77..c406d1a68 100644 --- a/singlestoredb/functions/ext/rowdat_1.py +++ b/singlestoredb/functions/ext/rowdat_1.py @@ -12,10 +12,9 @@ from ...config import get_option from ...mysql.constants import FIELD_TYPE as ft from ..dtypes import DEFAULT_VALUES -from ..dtypes import NUMPY_TYPE_MAP -from ..dtypes import PANDAS_TYPE_MAP -from ..dtypes import POLARS_TYPE_MAP -from ..dtypes import PYARROW_TYPE_MAP +from ..dtypes import get_numpy_type_map +from ..dtypes import get_polars_type_map +from ..dtypes import 
get_pyarrow_type_map if TYPE_CHECKING: try: @@ -212,7 +211,7 @@ def _load_pandas( index = pd.Series(row_ids) return pd.Series(row_ids, dtype=np.int64), [ ( - pd.Series(data, index=index, name=name, dtype=PANDAS_TYPE_MAP[dtype]), + pd.Series(data, index=index, name=name, dtype=get_numpy_type_map()[dtype]), pd.Series(mask, index=index, dtype=np.bool_), ) for (data, mask), (name, dtype) in zip(cols, colspec) @@ -247,7 +246,7 @@ def _load_polars( return pl.Series(None, row_ids, dtype=pl.Int64), \ [ ( - pl.Series(name=name, values=data, dtype=POLARS_TYPE_MAP[dtype]), + pl.Series(name=name, values=data, dtype=get_polars_type_map()[dtype]), pl.Series(values=mask, dtype=pl.Boolean), ) for (data, mask), (name, dtype) in zip(cols, colspec) @@ -282,7 +281,7 @@ def _load_numpy( return np.asarray(row_ids, dtype=np.int64), \ [ ( - np.asarray(data, dtype=NUMPY_TYPE_MAP[dtype]), # type: ignore + np.asarray(data, dtype=get_numpy_type_map()[dtype]), # type: ignore np.asarray(mask, dtype=np.bool_), # type: ignore ) for (data, mask), (name, dtype) in zip(cols, colspec) @@ -318,7 +317,7 @@ def _load_arrow( [ ( pa.array( - data, type=PYARROW_TYPE_MAP[dtype], + data, type=get_pyarrow_type_map()[dtype], mask=pa.array(mask, type=pa.bool_()), ), pa.array(mask, type=pa.bool_()), @@ -565,7 +564,7 @@ def _load_pandas_accel( numpy_ids, numpy_cols = _singlestoredb_accel.load_rowdat_1_numpy(colspec, data) cols = [ ( - pd.Series(data, name=name, dtype=PANDAS_TYPE_MAP[dtype]), + pd.Series(data, name=name, dtype=get_numpy_type_map()[dtype]), pd.Series(mask, dtype=np.bool_), ) for (name, dtype), (data, mask) in zip(colspec, numpy_cols) @@ -610,7 +609,7 @@ def _load_polars_accel( pl.Series( name=name, values=data.tolist() if dtype in string_types or dtype in binary_types else data, - dtype=POLARS_TYPE_MAP[dtype], + dtype=get_polars_type_map()[dtype], ), pl.Series(values=mask, dtype=pl.Boolean), ) @@ -653,7 +652,7 @@ def _load_arrow_accel( numpy_ids, numpy_cols = 
_singlestoredb_accel.load_rowdat_1_numpy(colspec, data) cols = [ ( - pa.array(data, type=PYARROW_TYPE_MAP[dtype], mask=mask), + pa.array(data, type=get_pyarrow_type_map()[dtype], mask=mask), pa.array(mask, type=pa.bool_()), ) for (data, mask), (name, dtype) in zip(numpy_cols, colspec) diff --git a/singlestoredb/tests/test_connection.py b/singlestoredb/tests/test_connection.py index ee392d06a..2ae5cf1d2 100755 --- a/singlestoredb/tests/test_connection.py +++ b/singlestoredb/tests/test_connection.py @@ -22,8 +22,10 @@ try: import pandas as pd has_pandas = True + _pd_str_dtype = str(pd.DataFrame({'a': ['x']}).dtypes['a']) except ImportError: has_pandas = False + _pd_str_dtype = 'object' class TestConnection(unittest.TestCase): @@ -1124,21 +1126,21 @@ def test_alltypes_pandas(self): ('timestamp', 'datetime64[us]'), ('timestamp_6', 'datetime64[us]'), ('year', 'float64'), - ('char_100', 'object'), + ('char_100', _pd_str_dtype), ('binary_100', 'object'), - ('varchar_200', 'object'), + ('varchar_200', _pd_str_dtype), ('varbinary_200', 'object'), - ('longtext', 'object'), - ('mediumtext', 'object'), - ('text', 'object'), - ('tinytext', 'object'), + ('longtext', _pd_str_dtype), + ('mediumtext', _pd_str_dtype), + ('text', _pd_str_dtype), + ('tinytext', _pd_str_dtype), ('longblob', 'object'), ('mediumblob', 'object'), ('blob', 'object'), ('tinyblob', 'object'), ('json', 'object'), - ('enum', 'object'), - ('set', 'object'), + ('enum', _pd_str_dtype), + ('set', _pd_str_dtype), ('bit', 'object'), ] @@ -1266,21 +1268,21 @@ def test_alltypes_no_nulls_pandas(self): ('timestamp', 'datetime64[us]'), ('timestamp_6', 'datetime64[us]'), ('year', 'int16'), - ('char_100', 'object'), + ('char_100', _pd_str_dtype), ('binary_100', 'object'), - ('varchar_200', 'object'), + ('varchar_200', _pd_str_dtype), ('varbinary_200', 'object'), - ('longtext', 'object'), - ('mediumtext', 'object'), - ('text', 'object'), - ('tinytext', 'object'), + ('longtext', _pd_str_dtype), + ('mediumtext', 
_pd_str_dtype), + ('text', _pd_str_dtype), + ('tinytext', _pd_str_dtype), ('longblob', 'object'), ('mediumblob', 'object'), ('blob', 'object'), ('tinyblob', 'object'), ('json', 'object'), - ('enum', 'object'), - ('set', 'object'), + ('enum', _pd_str_dtype), + ('set', _pd_str_dtype), ('bit', 'object'), ] diff --git a/singlestoredb/utils/_lazy_import.py b/singlestoredb/utils/_lazy_import.py new file mode 100644 index 000000000..7bc532546 --- /dev/null +++ b/singlestoredb/utils/_lazy_import.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +"""Lazy import utilities for heavy optional dependencies.""" +import importlib +from functools import lru_cache +from typing import Any +from typing import Optional + + +@lru_cache(maxsize=None) +def get_numpy() -> Optional[Any]: + """Return numpy module or None if not installed.""" + try: + return importlib.import_module('numpy') + except ImportError: + return None + + +@lru_cache(maxsize=None) +def get_pandas() -> Optional[Any]: + """Return pandas module or None if not installed.""" + try: + return importlib.import_module('pandas') + except ImportError: + return None + + +@lru_cache(maxsize=None) +def get_polars() -> Optional[Any]: + """Return polars module or None if not installed.""" + try: + return importlib.import_module('polars') + except ImportError: + return None + + +@lru_cache(maxsize=None) +def get_pyarrow() -> Optional[Any]: + """Return pyarrow module or None if not installed.""" + try: + return importlib.import_module('pyarrow') + except ImportError: + return None diff --git a/singlestoredb/utils/dtypes.py b/singlestoredb/utils/dtypes.py index 73eb893c1..e343110df 100644 --- a/singlestoredb/utils/dtypes.py +++ b/singlestoredb/utils/dtypes.py @@ -1,22 +1,11 @@ #!/usr/bin/env python3 +from functools import lru_cache +from typing import Any +from typing import Dict -try: - import numpy as np - has_numpy = True -except ImportError: - has_numpy = False - -try: - import polars as pl - has_polars = True -except ImportError: - 
has_polars = False - -try: - import pyarrow as pa - has_pyarrow = True -except ImportError: - has_pyarrow = False +from ._lazy_import import get_numpy +from ._lazy_import import get_polars +from ._lazy_import import get_pyarrow DEFAULT_VALUES = { @@ -64,8 +53,13 @@ } -if has_numpy: - NUMPY_TYPE_MAP = { +@lru_cache(maxsize=None) +def get_numpy_type_map() -> Dict[int, Any]: + """Return numpy type map, or empty dict if numpy is not installed.""" + np = get_numpy() + if np is None: + return {} + return { 0: object, # Decimal 1: np.int8, # Tiny -1: np.uint8, # Unsigned Tiny @@ -107,13 +101,15 @@ -254: object, # Binary 255: object, # Geometry } -else: - NUMPY_TYPE_MAP = {} -PANDAS_TYPE_MAP = NUMPY_TYPE_MAP -if has_pyarrow: - PYARROW_TYPE_MAP = { +@lru_cache(maxsize=None) +def get_pyarrow_type_map() -> Dict[int, Any]: + """Return pyarrow type map, or empty dict if pyarrow is not installed.""" + pa = get_pyarrow() + if pa is None: + return {} + return { 0: pa.decimal128(18, 6), # Decimal 1: pa.int8(), # Tiny -1: pa.uint8(), # Unsigned Tiny @@ -155,11 +151,15 @@ -254: pa.binary(), # Binary 255: pa.string(), # Geometry } -else: - PYARROW_TYPE_MAP = {} -if has_polars: - POLARS_TYPE_MAP = { + +@lru_cache(maxsize=None) +def get_polars_type_map() -> Dict[int, Any]: + """Return polars type map, or empty dict if polars is not installed.""" + pl = get_polars() + if pl is None: + return {} + return { 0: pl.Decimal(10, 6), # Decimal 1: pl.Int8, # Tiny -1: pl.UInt8, # Unsigned Tiny @@ -201,5 +201,3 @@ -254: pl.Binary, # Binary 255: pl.Utf8, # Geometry } -else: - POLARS_TYPE_MAP = {} diff --git a/singlestoredb/utils/results.py b/singlestoredb/utils/results.py index 838465714..3264bd5f2 100644 --- a/singlestoredb/utils/results.py +++ b/singlestoredb/utils/results.py @@ -2,6 +2,7 @@ """SingleStoreDB package utilities.""" import collections import warnings +from functools import lru_cache from typing import Any from typing import Callable from typing import Dict @@ -9,47 +10,34 @@ from 
typing import NamedTuple from typing import Optional from typing import Tuple +from typing import TYPE_CHECKING from typing import Union -from .dtypes import NUMPY_TYPE_MAP -from .dtypes import POLARS_TYPE_MAP -from .dtypes import PYARROW_TYPE_MAP +if TYPE_CHECKING: + import numpy + import pandas + import polars + import pyarrow + +from ._lazy_import import get_numpy +from ._lazy_import import get_pandas +from ._lazy_import import get_polars +from ._lazy_import import get_pyarrow +from .dtypes import get_numpy_type_map +from .dtypes import get_polars_type_map +from .dtypes import get_pyarrow_type_map UNSIGNED_FLAG = 32 BINARY_FLAG = 128 -try: - has_numpy = True - import numpy as np -except ImportError: - has_numpy = False - -try: - has_pandas = True - import pandas as pd -except ImportError: - has_pandas = False - -try: - has_polars = True - import polars as pl -except ImportError: - has_polars = False - -try: - has_pyarrow = True - import pyarrow as pa -except ImportError: - has_pyarrow = False - DBAPIResult = Union[List[Tuple[Any, ...]], Tuple[Any, ...]] OneResult = Union[ Tuple[Any, ...], Dict[str, Any], - 'np.ndarray', 'pd.DataFrame', 'pl.DataFrame', 'pa.Table', + 'numpy.ndarray', 'pandas.DataFrame', 'polars.DataFrame', 'pyarrow.Table', ] ManyResult = Union[ List[Tuple[Any, ...]], List[Dict[str, Any]], - 'np.ndarray', 'pd.DataFrame', 'pl.DataFrame', 'pa.Table', + 'numpy.ndarray', 'pandas.DataFrame', 'polars.DataFrame', 'pyarrow.Table', ] Result = Union[OneResult, ManyResult] @@ -67,11 +55,14 @@ class Description(NamedTuple): charset: Optional[int] -if has_numpy: - # If an int column is nullable, we need to use floats rather than - # ints for numpy and pandas. 
- NUMPY_TYPE_MAP_CAST_FLOAT = NUMPY_TYPE_MAP.copy() - NUMPY_TYPE_MAP_CAST_FLOAT.update({ +@lru_cache(maxsize=None) +def _get_numpy_type_map_cast_float() -> Dict[int, Any]: + """Return numpy type map with int types cast to float for nullable columns.""" + np = get_numpy() + if np is None: + return {} + type_map = get_numpy_type_map().copy() + type_map.update({ 1: np.float32, # Tiny -1: np.float32, # Unsigned Tiny 2: np.float32, # Short @@ -84,15 +75,23 @@ class Description(NamedTuple): -9: np.float64, # Unsigned Int24 13: np.float64, # Year }) + return type_map -if has_polars: + +@lru_cache(maxsize=None) +def _get_polars_type_map_with_dates() -> Dict[int, Any]: + """Return polars type map with date/times remapped to strings.""" + pl = get_polars() + if pl is None: + return {} + type_map = get_polars_type_map().copy() # Remap date/times to strings; let polars do the parsing - POLARS_TYPE_MAP = POLARS_TYPE_MAP.copy() - POLARS_TYPE_MAP.update({ + type_map.update({ 7: pl.Utf8, 10: pl.Utf8, 12: pl.Utf8, }) + return type_map INT_TYPES = set([1, 2, 3, 8, 9]) @@ -109,13 +108,15 @@ def signed(desc: Description) -> int: def _description_to_numpy_schema(desc: List[Description]) -> Dict[str, Any]: """Convert description to numpy array schema info.""" - if has_numpy: + if get_numpy() is not None: + numpy_type_map = get_numpy_type_map() + numpy_type_map_cast_float = _get_numpy_type_map_cast_float() return dict( dtype=[ ( x.name, - NUMPY_TYPE_MAP_CAST_FLOAT[signed(x)] - if x.null_ok else NUMPY_TYPE_MAP[signed(x)], + numpy_type_map_cast_float[signed(x)] + if x.null_ok else numpy_type_map[signed(x)], ) for x in desc ], @@ -125,18 +126,21 @@ def _description_to_numpy_schema(desc: List[Description]) -> Dict[str, Any]: def _description_to_pandas_schema(desc: List[Description]) -> Dict[str, Any]: """Convert description to pandas DataFrame schema info.""" - if has_pandas: + if get_pandas() is not None: return dict(columns=[x.name for x in desc]) return {} -def _decimalize_polars(desc: 
Description) -> 'pl.Decimal': - return pl.Decimal(desc.precision or 10, desc.scale or 0) +def _decimalize_polars(desc: Description) -> Any: + pl = get_polars() + return pl.Decimal(desc.precision or 10, desc.scale or 0) # type: ignore[union-attr] def _description_to_polars_schema(desc: List[Description]) -> Dict[str, Any]: """Convert description to polars DataFrame schema info.""" - if has_polars: + pl = get_polars() + if pl is not None: + polars_type_map = _get_polars_type_map_with_dates() with_columns = {} for x in desc: if x.type_code in [7, 12]: @@ -156,7 +160,8 @@ def _description_to_polars_schema(desc: List[Description]) -> Dict[str, Any]: schema=[ ( x.name, _decimalize_polars(x) - if x.type_code in DECIMAL_TYPES else POLARS_TYPE_MAP[signed(x)], + if x.type_code in DECIMAL_TYPES + else polars_type_map[signed(x)], ) for x in desc ], @@ -166,18 +171,24 @@ def _description_to_polars_schema(desc: List[Description]) -> Dict[str, Any]: return {} -def _decimalize_arrow(desc: Description) -> 'pa.Decimal128': - return pa.decimal128(desc.precision or 10, desc.scale or 0) +def _decimalize_arrow(desc: Description) -> Any: + pa = get_pyarrow() + return pa.decimal128( # type: ignore[union-attr] + desc.precision or 10, desc.scale or 0, + ) def _description_to_arrow_schema(desc: List[Description]) -> Dict[str, Any]: """Convert description to Arrow Table schema info.""" - if has_pyarrow: + pa = get_pyarrow() + if pa is not None: + pyarrow_type_map = get_pyarrow_type_map() return dict( schema=pa.schema([ ( x.name, _decimalize_arrow(x) - if x.type_code in DECIMAL_TYPES else PYARROW_TYPE_MAP[signed(x)], + if x.type_code in DECIMAL_TYPES + else pyarrow_type_map[signed(x)], ) for x in desc ]), @@ -215,7 +226,8 @@ def results_to_numpy( """ if not res: return res - if has_numpy: + np = get_numpy() + if np is not None: schema = _description_to_numpy_schema(desc) if schema is None else schema if single: return np.array([res], **schema) @@ -257,7 +269,8 @@ def results_to_pandas( """ if 
not res: return res - if has_pandas: + pd = get_pandas() + if pd is not None: schema = _description_to_pandas_schema(desc) if schema is None else schema return pd.DataFrame(results_to_numpy(desc, res, single=single, schema=schema)) warnings.warn( @@ -297,7 +310,8 @@ def results_to_polars( """ if not res: return res - if has_polars: + pl = get_polars() + if pl is not None: schema = _description_to_polars_schema(desc) if schema is None else schema if single: out = pl.DataFrame([res], orient='row', **schema.get('schema', {})) @@ -344,7 +358,8 @@ def results_to_arrow( """ if not res: return res - if has_pyarrow: + pa = get_pyarrow() + if pa is not None: names = [x[0] for x in desc] schema = _description_to_arrow_schema(desc) if schema is None else schema if single: From 24df239c9d2d9abe4d24e64f474dfc8b6b6b978d Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Mon, 23 Mar 2026 13:00:55 -0500 Subject: [PATCH 06/19] Fix Python 3.10+ union syntax in udf_handler type annotation Replace `str | None` with `Optional[str]` to maintain compatibility with Python 3.9 and earlier. 
Co-Authored-By: Claude Opus 4.6 --- singlestoredb/functions/ext/wasm/udf_handler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/singlestoredb/functions/ext/wasm/udf_handler.py b/singlestoredb/functions/ext/wasm/udf_handler.py index eb4f13feb..fb6f66946 100644 --- a/singlestoredb/functions/ext/wasm/udf_handler.py +++ b/singlestoredb/functions/ext/wasm/udf_handler.py @@ -22,6 +22,7 @@ from typing import Callable from typing import Dict from typing import List +from typing import Optional # Install numpy stub before importing singlestoredb (which tries to import numpy) @@ -58,7 +59,7 @@ class _TracingFormatter(logging.Formatter): 'CRITICAL': '\033[31m', # red } - def formatTime(self, record: logging.LogRecord, datefmt: str | None = None) -> str: + def formatTime(self, record: logging.LogRecord, datefmt: Optional[str] = None) -> str: from datetime import datetime, timezone dt = datetime.fromtimestamp(record.created, tz=timezone.utc) return dt.strftime('%Y-%m-%dT%H:%M:%S.') + f'{dt.microsecond:06d}Z' From 5eef5fd2c96c6b603b119cf266de5ce20c2530e4 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Mon, 23 Mar 2026 13:13:19 -0500 Subject: [PATCH 07/19] feat: add call_function_accel C function to accel.c Add the call_function_accel function directly to accel.c, implementing a combined load/call/dump operation for UDF function calls. This function handles rowdat_1 deserialization, Python UDF invocation, and result serialization in a single optimized C implementation. Previously this function was injected at build time via a patch script in the wasm-udf-server repository. Moving it into the source tree is a prerequisite for cleaning up the custom componentize-py builder and simplifying the WASM component build process. 
Co-Authored-By: Claude Opus 4.6 --- accel.c | 520 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 520 insertions(+) diff --git a/accel.c b/accel.c index ee9a72f3f..29d482bfc 100644 --- a/accel.c +++ b/accel.c @@ -4774,12 +4774,532 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) } +static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *kwargs) { + PyObject *py_colspec = NULL, *py_returns = NULL, *py_data = NULL, *py_func = NULL; + PyObject *py_out = NULL, *py_row = NULL, *py_result = NULL, *py_result_item = NULL; + PyObject *py_str = NULL, *py_blob = NULL, *py_bytes = NULL; + Py_ssize_t length = 0; + uint64_t row_id = 0; + uint8_t is_null = 0; + int8_t i8 = 0; int16_t i16 = 0; int32_t i32 = 0; int64_t i64 = 0; + uint8_t u8 = 0; uint16_t u16 = 0; uint32_t u32 = 0; uint64_t u64 = 0; + float flt = 0; double dbl = 0; + int *ctypes = NULL, *rtypes = NULL; + char *data = NULL, *end = NULL, *out = NULL; + unsigned long long out_l = 0, out_idx = 0, colspec_l = 0, returns_l = 0, i = 0; + char *keywords[] = {"colspec", "returns", "data", "func", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOOO", keywords, + &py_colspec, &py_returns, &py_data, &py_func)) goto error; + if (!PyCallable_Check(py_func)) { + PyErr_SetString(PyExc_TypeError, "func must be callable"); goto error; + } + + CHECKRC(PyBytes_AsStringAndSize(py_data, &data, &length)); + end = data + (unsigned long long)length; + if (length == 0) { py_out = PyBytes_FromStringAndSize("", 0); goto exit; } + + // Parse colspec types + colspec_l = (unsigned long long)PyObject_Length(py_colspec); + ctypes = malloc(sizeof(int) * colspec_l); + if (!ctypes) goto error; + for (i = 0; i < colspec_l; i++) { + PyObject *py_cspec = PySequence_GetItem(py_colspec, i); + if (!py_cspec) goto error; + PyObject *py_ctype = PySequence_GetItem(py_cspec, 1); + if (!py_ctype) { Py_DECREF(py_cspec); goto error; } + ctypes[i] = 
(int)PyLong_AsLong(py_ctype); + Py_DECREF(py_ctype); Py_DECREF(py_cspec); + } + + // Parse return types + returns_l = (unsigned long long)PyObject_Length(py_returns); + rtypes = malloc(sizeof(int) * returns_l); + if (!rtypes) goto error; + for (i = 0; i < returns_l; i++) { + PyObject *py_item = PySequence_GetItem(py_returns, i); + if (!py_item) goto error; + rtypes[i] = (int)PyLong_AsLong(py_item); + Py_DECREF(py_item); + } + + out_l = 256; + out = malloc(out_l); + if (!out) goto error; + +#define CHECKMEM_CFA(x) \ + if ((out_idx + (x)) > out_l) { \ + out_l = out_l * 2 + (x); \ + char *new_out = realloc(out, out_l); \ + if (!new_out) { \ + PyErr_SetString(PyExc_MemoryError, "failed to reallocate output buffer"); \ + goto error; \ + } \ + out = new_out; \ + } + + // Main loop: parse input rows, call function, serialize output + while (end > data) { + py_row = PyTuple_New(colspec_l); + if (!py_row) goto error; + + // Read row ID + row_id = *(int64_t*)data; data += 8; + + // Parse input columns + for (i = 0; i < colspec_l; i++) { + is_null = data[0] == '\x01'; data += 1; + if (is_null) Py_INCREF(Py_None); + + switch (ctypes[i]) { + case MYSQL_TYPE_NULL: + data += 1; + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + Py_INCREF(Py_None); + break; + + case MYSQL_TYPE_TINY: + i8 = *(int8_t*)data; data += 1; + if (is_null) { + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + Py_INCREF(Py_None); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromLong((long)i8))); + } + break; + + case -MYSQL_TYPE_TINY: + u8 = *(uint8_t*)data; data += 1; + if (is_null) { + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + Py_INCREF(Py_None); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLong((unsigned long)u8))); + } + break; + + case MYSQL_TYPE_SHORT: + i16 = *(int16_t*)data; data += 2; + if (is_null) { + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + Py_INCREF(Py_None); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromLong((long)i16))); + } + break; + + 
case -MYSQL_TYPE_SHORT: + u16 = *(uint16_t*)data; data += 2; + if (is_null) { + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + Py_INCREF(Py_None); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLong((unsigned long)u16))); + } + break; + + case MYSQL_TYPE_LONG: + case MYSQL_TYPE_INT24: + i32 = *(int32_t*)data; data += 4; + if (is_null) { + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + Py_INCREF(Py_None); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromLong((long)i32))); + } + break; + + case -MYSQL_TYPE_LONG: + case -MYSQL_TYPE_INT24: + u32 = *(uint32_t*)data; data += 4; + if (is_null) { + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + Py_INCREF(Py_None); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLong((unsigned long)u32))); + } + break; + + case MYSQL_TYPE_LONGLONG: + i64 = *(int64_t*)data; data += 8; + if (is_null) { + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + Py_INCREF(Py_None); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromLongLong((long long)i64))); + } + break; + + case -MYSQL_TYPE_LONGLONG: + u64 = *(uint64_t*)data; data += 8; + if (is_null) { + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + Py_INCREF(Py_None); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLongLong((unsigned long long)u64))); + } + break; + + case MYSQL_TYPE_FLOAT: + flt = *(float*)data; data += 4; + if (is_null) { + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + Py_INCREF(Py_None); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyFloat_FromDouble((double)flt))); + } + break; + + case MYSQL_TYPE_DOUBLE: + dbl = *(double*)data; data += 8; + if (is_null) { + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + Py_INCREF(Py_None); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyFloat_FromDouble((double)dbl))); + } + break; + + case MYSQL_TYPE_DECIMAL: + case MYSQL_TYPE_NEWDECIMAL: + // TODO + break; + + case MYSQL_TYPE_DATE: + case MYSQL_TYPE_NEWDATE: + // TODO + break; + + case MYSQL_TYPE_TIME: + 
// TODO + break; + + case MYSQL_TYPE_DATETIME: + // TODO + break; + + case MYSQL_TYPE_TIMESTAMP: + // TODO + break; + + case MYSQL_TYPE_YEAR: + u16 = *(uint16_t*)data; data += 2; + if (is_null) { + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + Py_INCREF(Py_None); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLong((unsigned long)u16))); + } + break; + + case MYSQL_TYPE_VARCHAR: + case MYSQL_TYPE_JSON: + case MYSQL_TYPE_SET: + case MYSQL_TYPE_ENUM: + case MYSQL_TYPE_VAR_STRING: + case MYSQL_TYPE_STRING: + case MYSQL_TYPE_GEOMETRY: + case MYSQL_TYPE_TINY_BLOB: + case MYSQL_TYPE_MEDIUM_BLOB: + case MYSQL_TYPE_LONG_BLOB: + case MYSQL_TYPE_BLOB: + i64 = *(int64_t*)data; data += 8; + if (is_null) { + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + Py_INCREF(Py_None); + } else { + py_str = PyUnicode_FromStringAndSize(data, (Py_ssize_t)i64); + data += i64; + if (!py_str) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_str)); + py_str = NULL; + } + break; + + case -MYSQL_TYPE_VARCHAR: + case -MYSQL_TYPE_JSON: + case -MYSQL_TYPE_SET: + case -MYSQL_TYPE_ENUM: + case -MYSQL_TYPE_VAR_STRING: + case -MYSQL_TYPE_STRING: + case -MYSQL_TYPE_GEOMETRY: + case -MYSQL_TYPE_TINY_BLOB: + case -MYSQL_TYPE_MEDIUM_BLOB: + case -MYSQL_TYPE_LONG_BLOB: + case -MYSQL_TYPE_BLOB: + i64 = *(int64_t*)data; data += 8; + if (is_null) { + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + Py_INCREF(Py_None); + } else { + py_blob = PyBytes_FromStringAndSize(data, (Py_ssize_t)i64); + data += i64; + if (!py_blob) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_blob)); + py_blob = NULL; + } + break; + + default: + goto error; + } + } + + // Call the user function + py_result = PyObject_Call(py_func, py_row, NULL); + Py_DECREF(py_row); + py_row = NULL; + if (!py_result) goto error; + + // Normalize result: wrap scalar in a tuple + if (!PyList_Check(py_result) && !PyTuple_Check(py_result)) { + PyObject *py_wrapped = PyTuple_Pack(1, py_result); + Py_DECREF(py_result); + py_result = 
py_wrapped; + if (!py_result) goto error; + } + + // Write row ID to output + CHECKMEM_CFA(8); + memcpy(out+out_idx, &row_id, 8); + out_idx += 8; + + // Serialize output columns + for (i = 0; i < returns_l; i++) { + py_result_item = PySequence_GetItem(py_result, i); + if (!py_result_item) goto error; + + is_null = (uint8_t)(py_result_item == Py_None); + + CHECKMEM_CFA(1); + memcpy(out+out_idx, &is_null, 1); + out_idx += 1; + + switch (rtypes[i]) { + case MYSQL_TYPE_BIT: + // TODO + break; + + case MYSQL_TYPE_TINY: + CHECKMEM_CFA(1); + i8 = (is_null) ? 0 : (int8_t)PyLong_AsLong(py_result_item); + memcpy(out+out_idx, &i8, 1); + out_idx += 1; + break; + + case -MYSQL_TYPE_TINY: + CHECKMEM_CFA(1); + u8 = (is_null) ? 0 : (uint8_t)PyLong_AsUnsignedLong(py_result_item); + memcpy(out+out_idx, &u8, 1); + out_idx += 1; + break; + + case MYSQL_TYPE_SHORT: + CHECKMEM_CFA(2); + i16 = (is_null) ? 0 : (int16_t)PyLong_AsLong(py_result_item); + memcpy(out+out_idx, &i16, 2); + out_idx += 2; + break; + + case -MYSQL_TYPE_SHORT: + CHECKMEM_CFA(2); + u16 = (is_null) ? 0 : (uint16_t)PyLong_AsUnsignedLong(py_result_item); + memcpy(out+out_idx, &u16, 2); + out_idx += 2; + break; + + case MYSQL_TYPE_LONG: + case MYSQL_TYPE_INT24: + CHECKMEM_CFA(4); + i32 = (is_null) ? 0 : (int32_t)PyLong_AsLong(py_result_item); + memcpy(out+out_idx, &i32, 4); + out_idx += 4; + break; + + case -MYSQL_TYPE_LONG: + case -MYSQL_TYPE_INT24: + CHECKMEM_CFA(4); + u32 = (is_null) ? 0 : (uint32_t)PyLong_AsUnsignedLong(py_result_item); + memcpy(out+out_idx, &u32, 4); + out_idx += 4; + break; + + case MYSQL_TYPE_LONGLONG: + CHECKMEM_CFA(8); + i64 = (is_null) ? 0 : (int64_t)PyLong_AsLongLong(py_result_item); + memcpy(out+out_idx, &i64, 8); + out_idx += 8; + break; + + case -MYSQL_TYPE_LONGLONG: + CHECKMEM_CFA(8); + u64 = (is_null) ? 0 : (uint64_t)PyLong_AsUnsignedLongLong(py_result_item); + memcpy(out+out_idx, &u64, 8); + out_idx += 8; + break; + + case MYSQL_TYPE_FLOAT: + CHECKMEM_CFA(4); + flt = (is_null) ? 
0 : (float)PyFloat_AsDouble(py_result_item); + memcpy(out+out_idx, &flt, 4); + out_idx += 4; + break; + + case MYSQL_TYPE_DOUBLE: + CHECKMEM_CFA(8); + dbl = (is_null) ? 0 : (double)PyFloat_AsDouble(py_result_item); + memcpy(out+out_idx, &dbl, 8); + out_idx += 8; + break; + + case MYSQL_TYPE_DECIMAL: + // TODO + break; + + case MYSQL_TYPE_DATE: + case MYSQL_TYPE_NEWDATE: + // TODO + break; + + case MYSQL_TYPE_TIME: + // TODO + break; + + case MYSQL_TYPE_DATETIME: + // TODO + break; + + case MYSQL_TYPE_TIMESTAMP: + // TODO + break; + + case MYSQL_TYPE_YEAR: + CHECKMEM_CFA(2); + i16 = (is_null) ? 0 : (int16_t)PyLong_AsLong(py_result_item); + memcpy(out+out_idx, &i16, 2); + out_idx += 2; + break; + + case MYSQL_TYPE_VARCHAR: + case MYSQL_TYPE_JSON: + case MYSQL_TYPE_SET: + case MYSQL_TYPE_ENUM: + case MYSQL_TYPE_VAR_STRING: + case MYSQL_TYPE_STRING: + case MYSQL_TYPE_GEOMETRY: + case MYSQL_TYPE_TINY_BLOB: + case MYSQL_TYPE_MEDIUM_BLOB: + case MYSQL_TYPE_LONG_BLOB: + case MYSQL_TYPE_BLOB: + if (is_null) { + CHECKMEM_CFA(8); + i64 = 0; + memcpy(out+out_idx, &i64, 8); + out_idx += 8; + } else { + py_bytes = PyUnicode_AsEncodedString(py_result_item, "utf-8", "strict"); + if (!py_bytes) { + Py_DECREF(py_result_item); + goto error; + } + + char *str = NULL; + Py_ssize_t str_l = 0; + if (PyBytes_AsStringAndSize(py_bytes, &str, &str_l) < 0) { + Py_DECREF(py_bytes); + Py_DECREF(py_result_item); + goto error; + } + + CHECKMEM_CFA(8+str_l); + i64 = str_l; + memcpy(out+out_idx, &i64, 8); + out_idx += 8; + memcpy(out+out_idx, str, str_l); + out_idx += str_l; + Py_DECREF(py_bytes); + py_bytes = NULL; + } + break; + + case -MYSQL_TYPE_VARCHAR: + case -MYSQL_TYPE_JSON: + case -MYSQL_TYPE_SET: + case -MYSQL_TYPE_ENUM: + case -MYSQL_TYPE_VAR_STRING: + case -MYSQL_TYPE_STRING: + case -MYSQL_TYPE_GEOMETRY: + case -MYSQL_TYPE_TINY_BLOB: + case -MYSQL_TYPE_MEDIUM_BLOB: + case -MYSQL_TYPE_LONG_BLOB: + case -MYSQL_TYPE_BLOB: + if (is_null) { + CHECKMEM_CFA(8); + i64 = 0; + memcpy(out+out_idx, 
&i64, 8); + out_idx += 8; + } else { + char *str = NULL; + Py_ssize_t str_l = 0; + if (PyBytes_AsStringAndSize(py_result_item, &str, &str_l) < 0) { + Py_DECREF(py_result_item); + goto error; + } + + CHECKMEM_CFA(8+str_l); + i64 = str_l; + memcpy(out+out_idx, &i64, 8); + out_idx += 8; + memcpy(out+out_idx, str, str_l); + out_idx += str_l; + } + break; + + default: + Py_DECREF(py_result_item); + goto error; + } + + Py_DECREF(py_result_item); + py_result_item = NULL; + } + + Py_DECREF(py_result); + py_result = NULL; + } + +#undef CHECKMEM_CFA + + py_out = PyBytes_FromStringAndSize(out, out_idx); + +exit: + if (out) free(out); + if (ctypes) free(ctypes); + if (rtypes) free(rtypes); + + Py_XDECREF(py_row); + Py_XDECREF(py_result); + Py_XDECREF(py_result_item); + Py_XDECREF(py_str); + Py_XDECREF(py_blob); + Py_XDECREF(py_bytes); + + return py_out; + +error: + Py_XDECREF(py_out); + py_out = NULL; + + goto exit; +} + static PyMethodDef PyMySQLAccelMethods[] = { {"read_rowdata_packet", (PyCFunction)read_rowdata_packet, METH_VARARGS | METH_KEYWORDS, "PyMySQL row data packet reader"}, {"dump_rowdat_1", (PyCFunction)dump_rowdat_1, METH_VARARGS | METH_KEYWORDS, "ROWDAT_1 formatter for external functions"}, {"load_rowdat_1", (PyCFunction)load_rowdat_1, METH_VARARGS | METH_KEYWORDS, "ROWDAT_1 parser for external functions"}, {"dump_rowdat_1_numpy", (PyCFunction)dump_rowdat_1_numpy, METH_VARARGS | METH_KEYWORDS, "ROWDAT_1 formatter for external functions which takes numpy.arrays"}, {"load_rowdat_1_numpy", (PyCFunction)load_rowdat_1_numpy, METH_VARARGS | METH_KEYWORDS, "ROWDAT_1 parser for external functions which creates numpy.arrays"}, + {"call_function_accel", (PyCFunction)call_function_accel, METH_VARARGS | METH_KEYWORDS, "Combined load/call/dump for UDF function calls"}, {NULL, NULL, 0, NULL} }; From 0e4bb40f93c960973e0b3104d8c0d9e5d4444e2a Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 25 Mar 2026 09:15:24 -0500 Subject: [PATCH 08/19] Add WASM build script for 
wasm32-wasip2 target Add resources/build_wasm.sh that cross-compiles the package as a WASM wheel targeting wasm32-wasip2. The script sets up a host venv, configures the WASI SDK toolchain (clang, ar, linker flags), and uses `python -m build` to produce the wheel, then unpacks it into build/. Co-Authored-By: Claude Opus 4.6 --- resources/build_wasm.sh | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100755 resources/build_wasm.sh diff --git a/resources/build_wasm.sh b/resources/build_wasm.sh new file mode 100755 index 000000000..820acca04 --- /dev/null +++ b/resources/build_wasm.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +set -eou pipefail + +# CPYTHON_ROOT must contain a build of cpython for wasm32-wasip2 + +TARGET="wasm32-wasip2" +CROSS_BUILD="${CPYTHON_ROOT}/cross-build/${TARGET}" +WASI_SDK_PATH=${WASI_SDK_PATH:-/opt/wasi-sdk} +PYTHON_VERSION=$(grep '^VERSION=' "${CROSS_BUILD}/Makefile" | sed 's/VERSION=[[:space:]]*//') + +if [ ! -e wasm_venv ]; then + uv venv --python ${PYTHON_VERSION} wasm_venv +fi + +. 
wasm_venv/bin/activate + +HOST_PYTHON=$(which python3) + +uv pip install build wheel cython setuptools + +ARCH_TRIPLET=_wasi_wasm32-wasi + +export CC="${WASI_SDK_PATH}/bin/clang" +export CXX="${WASI_SDK_PATH}/bin/clang++" + +export PYTHONPATH="${CROSS_BUILD}/build/lib.wasi-wasm32-${PYTHON_VERSION}" + +export CFLAGS="--target=${TARGET} -fPIC -I${CROSS_BUILD}/install/include/python${PYTHON_VERSION} -D__EMSCRIPTEN__=1" +export CXXFLAGS="--target=${TARGET} -fPIC -I${CROSS_BUILD}/install/include/python${PYTHON_VERSION}" +export LDSHARED=${CC} +export AR="${WASI_SDK_PATH}/bin/ar" +export RANLIB=true +export LDFLAGS="--target=${TARGET} -shared -Wl,--allow-undefined" +export _PYTHON_SYSCONFIGDATA_NAME=_sysconfigdata__wasi_wasm32-wasi +export _PYTHON_HOST_PLATFORM=wasm32-wasi + +python3 -m build -n -w +wheel unpack --dest build dist/*.whl + +rm -rf ./wasm_venv From a18eb52cbf430dbf9913f215820ddf16aa2acdec Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 25 Mar 2026 15:34:00 -0500 Subject: [PATCH 09/19] Remove WASM numpy_stub, now unnecessary with lazy imports numpy is lazy-loaded throughout the codebase via the _lazy_import helpers, so the WASM numpy_stub that patched sys.modules['numpy'] is no longer needed. Delete the stub module and remove its references from udf_handler.py. Co-Authored-By: Claude Opus 4.6 --- .../functions/ext/wasm/numpy_stub.py | 127 ------------------ .../functions/ext/wasm/udf_handler.py | 15 +-- 2 files changed, 3 insertions(+), 139 deletions(-) delete mode 100644 singlestoredb/functions/ext/wasm/numpy_stub.py diff --git a/singlestoredb/functions/ext/wasm/numpy_stub.py b/singlestoredb/functions/ext/wasm/numpy_stub.py deleted file mode 100644 index 386c4a90d..000000000 --- a/singlestoredb/functions/ext/wasm/numpy_stub.py +++ /dev/null @@ -1,127 +0,0 @@ -""" -Minimal numpy stub for WASM environment. - -This provides just enough of numpy's interface for singlestoredb's -get_signature() function to work without actual numpy. 
-""" - - -class _DtypeMeta(type): - """Metaclass for dtype to support typing.get_origin checks.""" - pass - - -class dtype(metaclass=_DtypeMeta): - """Stub dtype class.""" - def __init__(self, spec: object = None) -> None: - self.spec = spec - - def __repr__(self) -> str: - return f'dtype({self.spec!r})' - - -class ndarray: - """Stub ndarray class.""" - pass - - -# Stub type classes - these just need to exist for isinstance/issubclass checks -class bool_: - pass - - -class integer: - pass - - -class int_(integer): - pass - - -class int8(integer): - pass - - -class int16(integer): - pass - - -class int32(integer): - pass - - -class int64(integer): - pass - - -class uint(integer): - pass - - -class unsignedinteger(integer): - pass - - -class uint8(unsignedinteger): - pass - - -class uint16(unsignedinteger): - pass - - -class uint32(unsignedinteger): - pass - - -class uint64(unsignedinteger): - pass - - -class longlong(integer): - pass - - -class ulonglong(unsignedinteger): - pass - - -class str_: - pass - - -class bytes_: - pass - - -class float16: - pass - - -class float32: - pass - - -class float64: - pass - - -class double(float64): - pass - - -class float_(float64): - pass - - -class single(float32): - pass - - -class unicode_(str_): - pass - - -# Common aliases -float = float64 -int = int64 diff --git a/singlestoredb/functions/ext/wasm/udf_handler.py b/singlestoredb/functions/ext/wasm/udf_handler.py index fb6f66946..e59b08371 100644 --- a/singlestoredb/functions/ext/wasm/udf_handler.py +++ b/singlestoredb/functions/ext/wasm/udf_handler.py @@ -24,18 +24,9 @@ from typing import List from typing import Optional - -# Install numpy stub before importing singlestoredb (which tries to import numpy) -if 'numpy' not in sys.modules: - try: - import numpy # noqa: F401 - except ImportError: - from . 
import numpy_stub - sys.modules['numpy'] = numpy_stub - -from singlestoredb.functions.signature import get_signature -from singlestoredb.functions.ext.rowdat_1 import load as _load_rowdat_1 from singlestoredb.functions.ext.rowdat_1 import dump as _dump_rowdat_1 +from singlestoredb.functions.ext.rowdat_1 import load as _load_rowdat_1 +from singlestoredb.functions.signature import get_signature from singlestoredb.mysql.constants import FIELD_TYPE as ft try: @@ -144,7 +135,7 @@ def _is_stdlib_or_infra(mod_name: str, mod_file: str) -> bool: """ # Infrastructure modules that are part of this project _infra = frozenset({ - 'udf_handler', 'numpy_stub', + 'udf_handler', }) if mod_name in _infra: return True From 248f8971e67f0705001aaf965033bcc2d6c867f7 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Fri, 27 Mar 2026 12:43:54 -0500 Subject: [PATCH 10/19] Add collocated Python UDF server with pre-fork process mode Add a standalone collocated UDF server package that can run as a drop-in replacement for the Rust wasm-udf-server. Uses pre-fork worker processes (default) for true CPU parallelism, avoiding GIL contention in the C-accelerated call path. Thread pool mode is available via --process-mode thread. Collapse the wasm subpackage into a single wasm.py module since it only contained one class re-exported through __init__.py. 
Co-Authored-By: Claude Opus 4.6 --- pyproject.toml | 3 + .../functions/ext/collocated/__init__.py | 1 + .../functions/ext/collocated/__main__.py | 132 +++++++ .../functions/ext/collocated/connection.py | 250 ++++++++++++ .../functions/ext/collocated/control.py | 115 ++++++ .../udf_handler.py => collocated/registry.py} | 238 ++++-------- .../functions/ext/collocated/server.py | 365 ++++++++++++++++++ .../functions/ext/collocated/wasm.py | 60 +++ singlestoredb/functions/ext/wasm/__init__.py | 0 9 files changed, 1006 insertions(+), 158 deletions(-) create mode 100644 singlestoredb/functions/ext/collocated/__init__.py create mode 100644 singlestoredb/functions/ext/collocated/__main__.py create mode 100644 singlestoredb/functions/ext/collocated/connection.py create mode 100644 singlestoredb/functions/ext/collocated/control.py rename singlestoredb/functions/ext/{wasm/udf_handler.py => collocated/registry.py} (65%) create mode 100644 singlestoredb/functions/ext/collocated/server.py create mode 100644 singlestoredb/functions/ext/collocated/wasm.py delete mode 100644 singlestoredb/functions/ext/wasm/__init__.py diff --git a/pyproject.toml b/pyproject.toml index 8a8e91fd4..c910bce9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,9 @@ dev = [ "singlestoredb[test,docs,build]", ] +[project.scripts] +python-udf-server = "singlestoredb.functions.ext.collocated.__main__:main" + [project.entry-points.pytest11] singlestoredb = "singlestoredb.pytest" diff --git a/singlestoredb/functions/ext/collocated/__init__.py b/singlestoredb/functions/ext/collocated/__init__.py new file mode 100644 index 000000000..4b3400310 --- /dev/null +++ b/singlestoredb/functions/ext/collocated/__init__.py @@ -0,0 +1 @@ +"""High-performance collocated Python UDF server for SingleStoreDB.""" diff --git a/singlestoredb/functions/ext/collocated/__main__.py b/singlestoredb/functions/ext/collocated/__main__.py new file mode 100644 index 000000000..402050a47 --- /dev/null +++ 
b/singlestoredb/functions/ext/collocated/__main__.py @@ -0,0 +1,132 @@ +""" +CLI entry point for the collocated Python UDF server. + +Usage:: + + python -m singlestoredb.functions.ext.collocated \\ + --extension myfuncs \\ + --extension-path /home/user/libs \\ + --socket /tmp/my-udf.sock + +Arguments match the Rust wasm-udf-server CLI for drop-in compatibility. +""" +import argparse +import logging +import os +import secrets +import sys +import tempfile +from typing import Any + +from .registry import setup_logging +from .server import Server + +logger = logging.getLogger('collocated') + + +def main(argv: Any = None) -> None: + parser = argparse.ArgumentParser( + prog='python -m singlestoredb.functions.ext.collocated', + description='High-performance collocated Python UDF server', + ) + parser.add_argument( + '--extension', + default=os.environ.get('EXTERNAL_UDF_EXTENSION', ''), + help=( + 'Python module to import (e.g. myfuncs). ' + 'Env: EXTERNAL_UDF_EXTENSION' + ), + ) + parser.add_argument( + '--extension-path', + default=os.environ.get('EXTERNAL_UDF_EXTENSION_PATH', ''), + help=( + 'Colon-separated search dirs for the module. ' + 'Env: EXTERNAL_UDF_EXTENSION_PATH' + ), + ) + parser.add_argument( + '--socket', + default=os.environ.get( + 'EXTERNAL_UDF_SOCKET_PATH', + os.path.join( + tempfile.gettempdir(), + f'singlestore-udf-{os.getpid()}-{secrets.token_hex(4)}.sock', + ), + ), + help=( + 'Unix socket path. ' + 'Env: EXTERNAL_UDF_SOCKET_PATH' + ), + ) + parser.add_argument( + '--n-workers', + type=int, + default=int(os.environ.get('EXTERNAL_UDF_N_WORKERS', '0')), + help=( + 'Worker threads (0 = CPU count). ' + 'Env: EXTERNAL_UDF_N_WORKERS' + ), + ) + parser.add_argument( + '--max-connections', + type=int, + default=int(os.environ.get('EXTERNAL_UDF_MAX_CONNECTIONS', '32')), + help=( + 'Socket backlog. 
' + 'Env: EXTERNAL_UDF_MAX_CONNECTIONS' + ), + ) + parser.add_argument( + '--log-level', + default=os.environ.get('EXTERNAL_UDF_LOG_LEVEL', 'info'), + choices=['debug', 'info', 'warning', 'error'], + help=( + 'Logging level. ' + 'Env: EXTERNAL_UDF_LOG_LEVEL' + ), + ) + parser.add_argument( + '--process-mode', + default=os.environ.get('EXTERNAL_UDF_PROCESS_MODE', 'process'), + choices=['thread', 'process'], + help=( + 'Concurrency mode: "thread" uses a thread pool, ' + '"process" uses pre-fork workers for true CPU ' + 'parallelism. Env: EXTERNAL_UDF_PROCESS_MODE' + ), + ) + + args = parser.parse_args(argv) + + if not args.extension: + parser.error( + '--extension is required ' + '(or set EXTERNAL_UDF_EXTENSION env var)', + ) + + # Setup logging + level = getattr(logging, args.log_level.upper()) + setup_logging(level) + + config = { + 'extension': args.extension, + 'extension_path': args.extension_path, + 'socket': args.socket, + 'n_workers': args.n_workers, + 'max_connections': args.max_connections, + 'process_mode': args.process_mode, + } + + server = Server(config) + try: + server.run() + except RuntimeError as exc: + logger.error(str(exc)) + sys.exit(1) + except KeyboardInterrupt: + pass + + +if __name__ == '__main__': + main() diff --git a/singlestoredb/functions/ext/collocated/connection.py b/singlestoredb/functions/ext/collocated/connection.py new file mode 100644 index 000000000..47921485d --- /dev/null +++ b/singlestoredb/functions/ext/collocated/connection.py @@ -0,0 +1,250 @@ +""" +Connection handler: protocol, mmap I/O, request loop. + +Implements the binary socket protocol matching the Rust wasm-udf-server: +handshake, control signal dispatch, and UDF request loop with mmap I/O. 
+""" +from __future__ import annotations + +import array +import logging +import mmap +import os +import select +import socket +import struct +import threading +import traceback +from typing import TYPE_CHECKING + +from .control import dispatch_control_signal +from .registry import call_function + +if TYPE_CHECKING: + from .server import SharedRegistry + +logger = logging.getLogger('collocated.connection') + +# Protocol constants +PROTOCOL_VERSION = 1 +STATUS_OK = 200 +STATUS_BAD_REQUEST = 400 +STATUS_ERROR = 500 + +# Minimum output mmap size to avoid repeated ftruncate +_MIN_OUTPUT_SIZE = 128 * 1024 + + +def handle_connection( + conn: socket.socket, + shared_registry: SharedRegistry, + shutdown_event: threading.Event, +) -> None: + """Handle a single client connection (runs in a thread pool worker).""" + try: + _handle_connection_inner(conn, shared_registry, shutdown_event) + except Exception: + logger.error(f'Connection error:\n{traceback.format_exc()}') + finally: + try: + conn.close() + except OSError: + pass + + +def _handle_connection_inner( + conn: socket.socket, + shared_registry: SharedRegistry, + shutdown_event: threading.Event, +) -> None: + """Inner connection handler (may raise).""" + # --- Handshake --- + # Receive 16 bytes: [version: u64 LE][namelen: u64 LE] + header = _recv_exact(conn, 16) + if header is None: + return + version, namelen = struct.unpack(' None: + """Handle a @@-prefixed control signal (one-shot request-response).""" + try: + # Read 8-byte request length + len_buf = _recv_exact(conn, 8) + if len_buf is None: + return + length = struct.unpack(' 0: + mem = mmap.mmap( + input_fd, length, mmap.MAP_SHARED, mmap.PROT_READ, + ) + try: + request_data = mem[:length] + finally: + mem.close() + + # Dispatch + result = dispatch_control_signal( + signal_name, request_data, shared_registry, + ) + + if result.ok: + # Write response to output mmap + response_bytes = result.data.encode('utf8') + response_size = len(response_bytes) + 
os.ftruncate(output_fd, max(_MIN_OUTPUT_SIZE, response_size)) + os.lseek(output_fd, 0, os.SEEK_SET) + os.write(output_fd, response_bytes) + + # Send [status=200, size] + conn.sendall(struct.pack(' None: + """Handle the UDF request loop for a single function.""" + # Track output mmap size to avoid repeated ftruncate + current_output_size = 0 + + try: + # Get thread-local registry + registry = shared_registry.get_thread_local_registry() + + while not shutdown_event.is_set(): + # Select-based recv with 100ms timeout for shutdown checks + readable, _, _ = select.select([conn], [], [], 0.1) + if not readable: + continue + + # Read 8-byte request length + len_buf = _recv_exact(conn, 8) + if len_buf is None: + break + length = struct.unpack(' current_output_size: + os.ftruncate(output_fd, needed) + current_output_size = needed + os.lseek(output_fd, 0, os.SEEK_SET) + os.write(output_fd, output_data) + + # Send [status=200, size] + conn.sendall(struct.pack(' bytes | None: + """Receive exactly n bytes, or return None on EOF.""" + buf = bytearray() + while len(buf) < n: + chunk = sock.recv(n - len(buf)) + if not chunk: + return None + buf.extend(chunk) + return bytes(buf) diff --git a/singlestoredb/functions/ext/collocated/control.py b/singlestoredb/functions/ext/collocated/control.py new file mode 100644 index 000000000..128414217 --- /dev/null +++ b/singlestoredb/functions/ext/collocated/control.py @@ -0,0 +1,115 @@ +""" +Control signal dispatch for @@health, @@functions, @@register. + +Matches the Rust wasm-udf-server's dispatch_control_signal behavior. 
+""" +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from .registry import describe_functions_json + +if TYPE_CHECKING: + from .server import SharedRegistry + +logger = logging.getLogger('collocated.control') + + +@dataclass +class ControlResult: + """Result of a control signal dispatch.""" + ok: bool + data: str # JSON response on success, error message on failure + + +def dispatch_control_signal( + signal_name: str, + request_data: bytes, + shared_registry: SharedRegistry, +) -> ControlResult: + """Dispatch a control signal to the appropriate handler.""" + try: + if signal_name == '@@health': + return _handle_health() + elif signal_name == '@@functions': + return _handle_functions(shared_registry) + elif signal_name == '@@register': + return _handle_register(request_data, shared_registry) + else: + return ControlResult( + ok=False, + data=f'Unknown control signal: {signal_name}', + ) + except Exception as e: + return ControlResult(ok=False, data=str(e)) + + +def _handle_health() -> ControlResult: + """Handle @@health: return status ok.""" + return ControlResult(ok=True, data='{"status":"ok"}') + + +def _handle_functions(shared_registry: SharedRegistry) -> ControlResult: + """Handle @@functions: return function descriptions.""" + registry = shared_registry.get_thread_local_registry() + json_str = describe_functions_json(registry) + return ControlResult(ok=True, data=f'{{"functions":{json_str}}}') + + +def _handle_register( + request_data: bytes, + shared_registry: SharedRegistry, +) -> ControlResult: + """Handle @@register: register a new function dynamically.""" + if not request_data: + return ControlResult(ok=False, data='Missing registration payload') + + try: + body = json.loads(request_data) + except json.JSONDecodeError as e: + return ControlResult(ok=False, data=f'Invalid JSON: {e}') + + function_name = body.get('function_name') + if not function_name: + return 
ControlResult( + ok=False, data='Missing required field: function_name', + ) + + args = body.get('args') + if not isinstance(args, list): + return ControlResult( + ok=False, data='Missing required field: args (must be an array)', + ) + + returns = body.get('returns') + if not isinstance(returns, list): + return ControlResult( + ok=False, + data='Missing required field: returns (must be an array)', + ) + + func_body = body.get('body') + if not func_body: + return ControlResult( + ok=False, data='Missing required field: body', + ) + + replace = body.get('replace', False) + + # Build signature JSON matching describe-functions schema + signature = json.dumps({ + 'name': function_name, + 'args': args, + 'returns': returns, + }) + + try: + shared_registry.create_function(signature, func_body, replace) + except Exception as e: + return ControlResult(ok=False, data=str(e)) + + logger.info(f"@@register: added function '{function_name}'") + return ControlResult(ok=True, data='{"status":"ok"}') diff --git a/singlestoredb/functions/ext/wasm/udf_handler.py b/singlestoredb/functions/ext/collocated/registry.py similarity index 65% rename from singlestoredb/functions/ext/wasm/udf_handler.py rename to singlestoredb/functions/ext/collocated/registry.py index e59b08371..f44c83be8 100644 --- a/singlestoredb/functions/ext/wasm/udf_handler.py +++ b/singlestoredb/functions/ext/collocated/registry.py @@ -1,16 +1,11 @@ """ -Python UDF handler implementing the WIT interface for WASM component. +Function registry for UDF discovery, registration, and invocation. -This module provides a Python runtime for UDF functions. When compiled -with componentize-py, it becomes a WASM component that can be loaded by -the Rust UDF server. - -Functions are discovered automatically by scanning sys.modules for -@udf-decorated functions. No _exports.py is needed — just import -FunctionHandler from this module in your UDF file and decorate -functions with @udf. 
+This module contains the core FunctionRegistry class (moved from +wasm/udf_handler.py) plus standalone call_function() and +describe_functions_json() helpers. Both the WASM handler and the +collocated server use these directly. """ -import difflib # noqa: F401 import inspect import json import logging @@ -18,11 +13,14 @@ import sys import traceback import types +from datetime import datetime +from datetime import timezone from typing import Any from typing import Callable from typing import Dict from typing import List from typing import Optional +from typing import Tuple from singlestoredb.functions.ext.rowdat_1 import dump as _dump_rowdat_1 from singlestoredb.functions.ext.rowdat_1 import load as _load_rowdat_1 @@ -50,8 +48,11 @@ class _TracingFormatter(logging.Formatter): 'CRITICAL': '\033[31m', # red } - def formatTime(self, record: logging.LogRecord, datefmt: Optional[str] = None) -> str: - from datetime import datetime, timezone + def formatTime( + self, + record: logging.LogRecord, + datefmt: Optional[str] = None, + ) -> str: dt = datetime.fromtimestamp(record.created, tz=timezone.utc) return dt.strftime('%Y-%m-%dT%H:%M:%S.') + f'{dt.microsecond:06d}Z' @@ -64,16 +65,18 @@ def format(self, record: logging.LogRecord) -> str: return f'{self._DIM}{ts}{self._RESET} {level} {name}: {msg}' -_handler = logging.StreamHandler() -_handler.setFormatter(_TracingFormatter()) -logging.basicConfig(level=logging.INFO, handlers=[_handler]) -logger = logging.getLogger('udf_handler') +def setup_logging(level: int = logging.INFO) -> None: + """Configure root logging with the tracing formatter.""" + handler = logging.StreamHandler() + handler.setFormatter(_TracingFormatter()) + logging.basicConfig(level=level, handlers=[handler]) + # Map dtype strings to rowdat_1 type codes for wire serialization. # rowdat_1 always uses 8-byte encoding for integers and doubles for floats, # so all int types collapse to LONGLONG and all float types to DOUBLE. 
# Uses negative values for unsigned ints / binary data. -rowdat_1_type_map = { +rowdat_1_type_map: Dict[str, int] = { 'bool': ft.LONGLONG, 'int8': ft.LONGLONG, 'int16': ft.LONGLONG, @@ -108,6 +111,8 @@ def format(self, record: logging.LogRecord) -> str: 'bytes': 'bytes', } +logger = logging.getLogger('udf_handler') + class FunctionRegistry: """Registry of discovered UDF functions.""" @@ -119,7 +124,7 @@ def initialize(self) -> None: """Initialize and discover UDF functions from loaded modules. Scans sys.modules for any module containing @udf-decorated - functions. No _exports.py is needed — modules just need to be + functions. No _exports.py is needed -- modules just need to be imported before initialize() is called (componentize-py captures them at build time). """ @@ -133,18 +138,15 @@ def _is_stdlib_or_infra(mod_name: str, mod_file: str) -> bool: (under sys.prefix but not in site-packages) rather than maintaining a hardcoded list of names. """ - # Infrastructure modules that are part of this project _infra = frozenset({ 'udf_handler', }) if mod_name in _infra: return True - # Resolve symlinks for reliable prefix comparison real_file = os.path.realpath(mod_file) real_prefix = os.path.realpath(sys.prefix) - # Modules under sys.prefix but NOT in site-packages are stdlib if real_file.startswith(real_prefix + os.sep): if 'site-packages' not in real_file: return True @@ -160,6 +162,9 @@ def _discover_udf_functions(self) -> None: Modules without a __file__ (built-in/frozen) and stdlib/ infrastructure modules are skipped automatically. 
""" + # Import here to avoid circular dependency at module level + from .wasm import FunctionHandler + found_modules = [] for mod_name, mod in list(sys.modules.items()): if mod is None: @@ -178,7 +183,6 @@ def _discover_udf_functions(self) -> None: ): continue - # Skip stdlib and infrastructure modules if self._is_stdlib_or_infra(mod_name, mod_file): continue @@ -201,11 +205,7 @@ def _discover_udf_functions(self) -> None: ) def _extract_functions(self, module: Any) -> None: - """Extract @udf-decorated functions from a module. - - Unlike module scanning, this does not filter by __module__ — - _exports.py may re-export functions defined in other modules. - """ + """Extract @udf-decorated functions from a module.""" for name, obj in inspect.getmembers(module): if name.startswith('_'): continue @@ -216,7 +216,6 @@ def _extract_functions(self, module: Any) -> None: if not inspect.isfunction(obj): continue - # Only register functions decorated with @udf if not hasattr(obj, '_singlestoredb_attrs'): continue @@ -225,17 +224,13 @@ def _extract_functions(self, module: Any) -> None: if sig and sig.get('args') is not None and sig.get('returns'): self._register_function(obj, name, sig) except (TypeError, ValueError): - # Skip functions that can't be introspected pass def _build_json_descriptions( self, func_names: List[str], ) -> List[Dict[str, Any]]: - """Build JSON-serializable descriptions for the given function names. - - Extracts metadata from the stored signature dict for each function. - """ + """Build JSON-serializable descriptions for the given function names.""" descriptions = [] for func_name in func_names: func_info = self.functions[func_name] @@ -269,10 +264,7 @@ def _build_json_descriptions( @staticmethod def _python_type_annotation(dtype: str) -> str: - """Convert a dtype string to a Python type annotation. - - Handles nullable types (trailing '?') by wrapping in Optional. 
- """ + """Convert a dtype string to a Python type annotation.""" nullable = dtype.endswith('?') base = dtype.rstrip('?') py_type = _dtype_to_python.get(base) @@ -287,27 +279,17 @@ def _build_python_code( sig: Dict[str, Any], body: str, ) -> str: - """Build a complete @udf-decorated Python function from signature and body. - - Args: - sig: Parsed signature dict with 'name', 'args', 'returns'. - body: The function body (e.g. "return x * 3"). - - Returns: - Complete Python source with imports and a @udf-decorated function. - """ + """Build a complete @udf-decorated Python function from sig + body.""" func_name = sig['name'] args = sig.get('args', []) returns = sig.get('returns', []) - # Build parameter list with type annotations params = [] for arg in args: ann = FunctionRegistry._python_type_annotation(arg['dtype']) params.append(f'{arg["name"]}: {ann}') params_str = ', '.join(params) - # Build return type annotation if len(returns) == 0: ret_ann = 'None' elif len(returns) == 1: @@ -321,7 +303,6 @@ def _build_python_code( ] ret_ann = f'Tuple[{", ".join(parts)}]' - # Indent body lines indented_body = '\n'.join( f' {line}' for line in body.splitlines() ) @@ -343,10 +324,6 @@ def create_function( ) -> List[str]: """Register a function from its signature and function body. - Constructs a complete @udf-decorated Python function from the - signature metadata and the raw function body, then compiles - and executes it. 
- Args: signature_json: JSON object matching the describe-functions element schema (must contain a 'name' field) @@ -355,10 +332,6 @@ def create_function( Returns: List of newly registered function names - - Raises: - SyntaxError: If the generated code has syntax errors - ValueError: If the function already exists and replace is False """ sig = json.loads(signature_json) func_name = sig.get('name') @@ -367,29 +340,20 @@ def create_function( 'signature JSON must contain a "name" field', ) - # Check for name collision when replace is False if not replace and func_name in self.functions: raise ValueError( f'Function "{func_name}" already exists ' f'(use replace=true to overwrite)', ) - # When replacing, remove the old entry so the new registration - # is detected as "new" by the before/after name comparison. if replace and func_name in self.functions: del self.functions[func_name] - # Build a complete @udf-decorated function from signature + body full_code = self._build_python_code(sig, code) - # Use __main__ as the module name for dynamically submitted functions name = '__main__' - - # Validate syntax compiled = compile(full_code, f'<{name}>', 'exec') - # Reuse existing module to avoid corrupting the componentize-py - # runtime state (replacing sys.modules['__main__'] traps WASM). if name in sys.modules: module = sys.modules[name] else: @@ -398,7 +362,6 @@ def create_function( sys.modules[name] = module exec(compiled, module.__dict__) # noqa: S102 - # Extract functions from the module before_names = set(self.functions.keys()) self._extract_functions(module) new_names = [k for k in self.functions if k not in before_names] @@ -421,30 +384,26 @@ def _register_function( func_name: str, sig: Dict[str, Any], ) -> None: - """Register a function under its bare name. - - All functions are registered as top-level names (no module prefix). - If a function with the same name already exists, the last - registration wins. 
- """ - # Use alias name from signature if available, otherwise use function name + """Register a function under its bare name.""" full_name = sig.get('name') or func_name - # Convert args to (name, type_code) tuples - arg_types = [] + arg_types: List[Tuple[str, int]] = [] for arg in sig['args']: dtype = arg['dtype'].replace('?', '') if dtype not in rowdat_1_type_map: - logger.warning(f"Skipping {full_name}: unsupported arg dtype '{dtype}'") + logger.warning( + f"Skipping {full_name}: unsupported arg dtype '{dtype}'", + ) return arg_types.append((arg['name'], rowdat_1_type_map[dtype])) - # Convert returns to type_code list - return_types = [] + return_types: List[int] = [] for ret in sig['returns']: dtype = ret['dtype'].replace('?', '') if dtype not in rowdat_1_type_map: - logger.warning(f'Skipping {full_name}: no type mapping for {dtype}') + logger.warning( + f'Skipping {full_name}: no type mapping for {dtype}', + ) return return_types.append(rowdat_1_type_map[dtype]) @@ -456,86 +415,49 @@ def _register_function( } -# Global registry instance -_registry = FunctionRegistry() +def call_function( + registry: FunctionRegistry, + name: str, + input_data: bytes, +) -> bytes: + """Call a registered UDF by name using the C accelerator or fallback. + This is the hot-path function used by both the WASM handler and + the collocated server. 
+ """ + if name not in registry.functions: + raise ValueError(f'unknown function: {name}') -class FunctionHandler: - """Implementation of the singlestore:udf/function-handler interface.""" + func_info = registry.functions[name] + func = func_info['func'] + arg_types = func_info['arg_types'] + return_types = func_info['return_types'] - def initialize(self) -> None: - """Initialize and discover UDF functions from loaded modules.""" + try: if _has_call_accel: - logger.info('Using accelerated C call_function_accel loop') - else: - logger.info('Using pure Python call_function loop') - _registry.initialize() - - def call_function(self, name: str, input_data: bytes) -> bytes: - """Call a function by its registered name.""" - if name not in _registry.functions: - raise ValueError(f'unknown function: {name}') - - func_info = _registry.functions[name] - func = func_info['func'] - arg_types = func_info['arg_types'] - return_types = func_info['return_types'] - - try: - if _has_call_accel: - return _call_function_accel( - colspec=arg_types, - returns=return_types, - data=input_data, - func=func, - ) - - # Fallback to pure Python - row_ids, rows = _load_rowdat_1(arg_types, input_data) - results = [] - for row in rows: - result = func(*row) - if not isinstance(result, tuple): - result = [result] - results.append(list(result)) - return bytes(_dump_rowdat_1(return_types, row_ids, results)) - - except Exception as e: - tb = traceback.format_exc() - raise RuntimeError(f'Error calling {name}: {e}\n{tb}') - - def describe_functions(self) -> str: - """Describe all functions as a JSON array. - - Returns a JSON string containing an array of function description - objects with: name, args, returns, args_data_format, - returns_data_format, function_type, doc. - - Raises RuntimeError on failure (mapped to result Err by - componentize-py). 
- """ - try: - func_names = list(_registry.functions.keys()) - descriptions = _registry._build_json_descriptions(func_names) - return json.dumps(descriptions) - except Exception as e: - tb = traceback.format_exc() - raise RuntimeError(f'{e}\n{tb}') - - def create_function( - self, - signature: str, - code: str, - replace: bool, - ) -> None: - """Register a function from its signature and Python source code. + return _call_function_accel( + colspec=arg_types, + returns=return_types, + data=input_data, + func=func, + ) - Returns None on success (mapped to result Ok(()) by componentize-py). - Raises RuntimeError on failure (mapped to result Err by - componentize-py). - """ - try: - _registry.create_function(signature, code, replace) - except Exception as e: - tb = traceback.format_exc() - raise RuntimeError(f'{e}\n{tb}') + row_ids, rows = _load_rowdat_1(arg_types, input_data) + results = [] + for row in rows: + result = func(*row) + if not isinstance(result, tuple): + result = [result] + results.append(list(result)) + return bytes(_dump_rowdat_1(return_types, row_ids, results)) + + except Exception as e: + tb = traceback.format_exc() + raise RuntimeError(f'Error calling {name}: {e}\n{tb}') + + +def describe_functions_json(registry: FunctionRegistry) -> str: + """Serialize all function descriptions as a JSON array string.""" + func_names = list(registry.functions.keys()) + descriptions = registry._build_json_descriptions(func_names) + return json.dumps(descriptions) diff --git a/singlestoredb/functions/ext/collocated/server.py b/singlestoredb/functions/ext/collocated/server.py new file mode 100644 index 000000000..ec5e31253 --- /dev/null +++ b/singlestoredb/functions/ext/collocated/server.py @@ -0,0 +1,365 @@ +""" +Server lifecycle: accept loop, thread pool, shutdown. + +Mirrors the Rust wasm-udf-server architecture with a ThreadPoolExecutor +for concurrent request handling and a SharedRegistry with generation- +counter caching for thread-safe live reload. 
+""" +import importlib +import logging +import multiprocessing +import os +import select +import signal +import socket +import sys +import threading +import traceback +from concurrent.futures import ThreadPoolExecutor +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +from .connection import handle_connection +from .registry import FunctionRegistry + +logger = logging.getLogger('collocated.server') + + +class SharedRegistry: + """Thread-safe wrapper around FunctionRegistry with generation caching. + + Each worker thread caches a (generation, FunctionRegistry) pair in + thread-local storage. When @@register bumps the generation, workers + create a fresh registry and replay all code blocks on next call. + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._generation: int = 0 + self._code_blocks: List[Tuple[str, str, bool]] = [] + self._base_registry: Optional[FunctionRegistry] = None + self._local = threading.local() + + def set_base_registry(self, registry: FunctionRegistry) -> None: + """Set the base registry (after initial module import + init).""" + with self._lock: + self._base_registry = registry + + @property + def generation(self) -> int: + return self._generation + + def create_function( + self, + signature_json: str, + code: str, + replace: bool, + ) -> List[str]: + """Register a new function and bump the generation counter. + + Thread-safe: acquires the lock, validates via a temporary + registry, stores the code block, and increments generation. 
+ """ + with self._lock: + # Validate on a temporary registry first + test_registry = self._build_fresh_registry() + new_names = test_registry.create_function( + signature_json, code, replace, + ) + # Success: store the code block and bump generation + self._code_blocks.append((signature_json, code, replace)) + self._generation += 1 + logger.info( + f'SharedRegistry: generation={self._generation}, ' + f'code_blocks={len(self._code_blocks)}', + ) + return new_names + + def get_thread_local_registry(self) -> FunctionRegistry: + """Get or refresh the thread-local cached registry. + + Cheap int comparison on the hot path; only rebuilds on + generation mismatch. + """ + cached = getattr(self._local, 'cached', None) + if cached is not None: + cached_gen, cached_reg = cached + if cached_gen == self._generation: + return cached_reg + + # Rebuild from base + code blocks + with self._lock: + registry = self._build_fresh_registry() + gen = self._generation + + self._local.cached = (gen, registry) + return registry + + def _build_fresh_registry(self) -> FunctionRegistry: + """Build a fresh registry with base functions + all code blocks. + + Must be called with self._lock held. + """ + registry = FunctionRegistry() + # Copy base functions + if self._base_registry is not None: + registry.functions = dict(self._base_registry.functions) + # Replay code blocks + for sig_json, code, replace in self._code_blocks: + registry.create_function(sig_json, code, replace) + return registry + + +class Server: + """Collocated UDF server with Unix socket + thread pool.""" + + def __init__(self, config: Dict[str, Any]) -> None: + self.config = config + self.shared_registry = SharedRegistry() + self.shutdown_event = threading.Event() + + def run(self) -> None: + """Run the server: import modules, bind socket, accept loop.""" + # 1. Import user modules & initialize registry + registry = self._initialize_registry() + self.shared_registry.set_base_registry(registry) + + # 2. 
Create & bind Unix socket + server_sock = self._bind_socket() + + # 3. Determine worker count and process mode + n_workers = self.config.get('n_workers', 0) + if n_workers <= 0: + n_workers = os.cpu_count() or 4 + + process_mode = self.config.get('process_mode', 'process') + + # 4. Signal handling (main process) + def _signal_handler(signum: int, frame: Any) -> None: + logger.info(f'Received signal {signum}, shutting down...') + self.shutdown_event.set() + + signal.signal(signal.SIGINT, _signal_handler) + signal.signal(signal.SIGTERM, _signal_handler) + + # 5. Dispatch to mode-specific loop + sock_path = self.config['socket'] + try: + if process_mode == 'process': + self._run_process_mode(server_sock, n_workers) + else: + self._run_thread_mode(server_sock, n_workers) + finally: + server_sock.close() + try: + os.unlink(sock_path) + except OSError: + pass + logger.info('Server stopped.') + + def _bind_socket(self) -> socket.socket: + """Create, bind, and listen on the Unix domain socket.""" + sock_path = self.config['socket'] + if os.path.exists(sock_path): + os.unlink(sock_path) + + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.bind(sock_path) + os.chmod(sock_path, 0o600) + + backlog = self.config.get('max_connections', 32) + server_sock.listen(backlog) + logger.info(f'Listening on {sock_path} (backlog={backlog})') + return server_sock + + def _run_thread_mode( + self, + server_sock: socket.socket, + n_workers: int, + ) -> None: + """Accept loop using a ThreadPoolExecutor.""" + pool = ThreadPoolExecutor(max_workers=n_workers) + logger.info(f'Thread pool: {n_workers} workers') + + try: + while not self.shutdown_event.is_set(): + readable, _, _ = select.select( + [server_sock], [], [], 0.1, + ) + if not readable: + continue + + conn, _ = server_sock.accept() + pool.submit( + handle_connection, + conn, + self.shared_registry, + self.shutdown_event, + ) + finally: + logger.info('Shutting down thread pool...') + pool.shutdown(wait=True) + 
+ def _run_process_mode( + self, + server_sock: socket.socket, + n_workers: int, + ) -> None: + """Pre-fork worker pool for true CPU parallelism.""" + ctx = multiprocessing.get_context('fork') + workers: Dict[int, multiprocessing.process.BaseProcess] = {} + + def _spawn_worker( + worker_id: int, + ) -> multiprocessing.process.BaseProcess: + p = ctx.Process( + target=self._worker_process_main, + args=(server_sock, worker_id), + daemon=True, + ) + p.start() + logger.info( + f'Started worker {worker_id} (pid={p.pid})', + ) + return p + + # Fork initial workers + logger.info( + f'Process pool: spawning {n_workers} workers', + ) + for i in range(n_workers): + workers[i] = _spawn_worker(i) + + # Monitor loop: restart dead workers + try: + while not self.shutdown_event.is_set(): + self.shutdown_event.wait(timeout=0.5) + for wid, proc in list(workers.items()): + if not proc.is_alive(): + exitcode = proc.exitcode + if not self.shutdown_event.is_set(): + logger.warning( + f'Worker {wid} (pid={proc.pid}) ' + f'exited with code {exitcode}, ' + f'restarting...', + ) + workers[wid] = _spawn_worker(wid) + finally: + logger.info('Shutting down worker processes...') + # Signal all workers to stop + for wid, proc in workers.items(): + if proc.is_alive(): + assert proc.pid is not None + os.kill(proc.pid, signal.SIGTERM) + + # Wait for graceful exit + for wid, proc in workers.items(): + proc.join(timeout=5.0) + if proc.is_alive(): + logger.warning( + f'Worker {wid} (pid={proc.pid}) ' + f'did not exit, terminating...', + ) + proc.terminate() + proc.join(timeout=2.0) + + def _worker_process_main( + self, + server_sock: socket.socket, + worker_id: int, + ) -> None: + """Entry point for each forked worker process.""" + try: + # Each worker gets its own registry and shutdown event + local_shared = SharedRegistry() + local_registry = FunctionRegistry() + local_registry.initialize() + local_shared.set_base_registry(local_registry) + + local_shutdown = threading.Event() + + def 
_worker_signal_handler( + signum: int, + frame: Any, + ) -> None: + local_shutdown.set() + + signal.signal(signal.SIGTERM, _worker_signal_handler) + signal.signal(signal.SIGINT, signal.SIG_IGN) + + # Set non-blocking so accept() raises BlockingIOError + # instead of blocking when another worker wins the race. + # O_NONBLOCK is on the open file description (shared across + # forked processes), but that's fine — all workers want + # non-blocking accept and the parent doesn't call accept. + server_sock.setblocking(False) + + logger.info( + f'Worker {worker_id} (pid={os.getpid()}) ready, ' + f'{len(local_registry.functions)} function(s)', + ) + + # Accept loop + while not local_shutdown.is_set(): + readable, _, _ = select.select( + [server_sock], [], [], 0.1, + ) + if not readable: + continue + + try: + conn, _ = server_sock.accept() + except BlockingIOError: + # Another worker won the accept race + continue + except OSError: + if local_shutdown.is_set(): + break + raise + + handle_connection( + conn, + local_shared, + local_shutdown, + ) + except Exception: + logger.error( + f'Worker {worker_id} crashed:\n' + f'{traceback.format_exc()}', + ) + raise + + def _initialize_registry(self) -> FunctionRegistry: + """Import the extension module and discover @udf functions.""" + extension = self.config['extension'] + extension_path = self.config.get('extension_path', '') + + # Prepend extension path directories to sys.path + if extension_path: + for p in reversed(extension_path.split(':')): + p = p.strip() + if p and p not in sys.path: + sys.path.insert(0, p) + logger.info(f'Added to sys.path: {p}') + + # Import the extension module + logger.info(f'Importing extension module: {extension}') + importlib.import_module(extension) + + # Initialize registry (discovers @udf functions from sys.modules) + registry = FunctionRegistry() + registry.initialize() + + func_count = len(registry.functions) + if func_count == 0: + raise RuntimeError( + f'No @udf functions found after importing 
{extension!r}', + ) + logger.info(f'Discovered {func_count} function(s)') + for name in sorted(registry.functions): + logger.info(f' function: {name}') + + return registry diff --git a/singlestoredb/functions/ext/collocated/wasm.py b/singlestoredb/functions/ext/collocated/wasm.py new file mode 100644 index 000000000..fad07b4aa --- /dev/null +++ b/singlestoredb/functions/ext/collocated/wasm.py @@ -0,0 +1,60 @@ +""" +Thin WIT adapter over FunctionRegistry. + +This module provides the FunctionHandler class that implements the +singlestore:udf/function-handler WIT interface by delegating to the +shared FunctionRegistry in registry.py. +""" +import logging +import traceback + +from .registry import _has_call_accel +from .registry import call_function +from .registry import describe_functions_json +from .registry import FunctionRegistry + +logger = logging.getLogger('udf_handler') + +# Global registry instance (used by WASM component runtime) +_registry = FunctionRegistry() + + +class FunctionHandler: + """Implementation of the singlestore:udf/function-handler interface.""" + + def initialize(self) -> None: + """Initialize and discover UDF functions from loaded modules.""" + if _has_call_accel: + logger.info('Using accelerated C call_function_accel loop') + else: + logger.info('Using pure Python call_function loop') + _registry.initialize() + + def call_function(self, name: str, input_data: bytes) -> bytes: + """Call a function by its registered name.""" + return call_function(_registry, name, input_data) + + def describe_functions(self) -> str: + """Describe all functions as a JSON array. + + Returns a JSON string containing an array of function + description objects. 
+ """ + try: + return describe_functions_json(_registry) + except Exception as e: + tb = traceback.format_exc() + raise RuntimeError(f'{e}\n{tb}') + + def create_function( + self, + signature: str, + code: str, + replace: bool, + ) -> None: + """Register a function from its signature and Python source code.""" + try: + _registry.create_function(signature, code, replace) + except Exception as e: + tb = traceback.format_exc() + raise RuntimeError(f'{e}\n{tb}') diff --git a/singlestoredb/functions/ext/wasm/__init__.py b/singlestoredb/functions/ext/wasm/__init__.py deleted file mode 100644 index e69de29bb..000000000 From e4d8627da62c5aeaff03b158911e723c51729abb Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Tue, 31 Mar 2026 10:15:59 -0500 Subject: [PATCH 11/19] Fix @@register propagation in collocated server process mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Each forked worker previously created its own independent SharedRegistry and FunctionRegistry. When @@register arrived at a worker, only that worker's local registry was updated — the main process and sibling workers never learned about the new function. Add Unix pipe-based IPC (matching the R UDF server fix): each worker gets a pipe back to the main process. When a worker handles @@register, it writes the registration payload to its pipe. The main process reads it via select.poll(), applies the registration to its own SharedRegistry, then kills and re-forks all workers so they inherit the updated state. Thread mode is unaffected — pipe_write_fd is None and the pipe write is a no-op. 
Co-Authored-By: Claude Opus 4.6 --- .../functions/ext/collocated/connection.py | 10 +- .../functions/ext/collocated/control.py | 23 +- .../functions/ext/collocated/server.py | 196 ++++++++++++++---- 3 files changed, 189 insertions(+), 40 deletions(-) diff --git a/singlestoredb/functions/ext/collocated/connection.py b/singlestoredb/functions/ext/collocated/connection.py index 47921485d..931ec1780 100644 --- a/singlestoredb/functions/ext/collocated/connection.py +++ b/singlestoredb/functions/ext/collocated/connection.py @@ -39,10 +39,13 @@ def handle_connection( conn: socket.socket, shared_registry: SharedRegistry, shutdown_event: threading.Event, + pipe_write_fd: int | None = None, ) -> None: """Handle a single client connection (runs in a thread pool worker).""" try: - _handle_connection_inner(conn, shared_registry, shutdown_event) + _handle_connection_inner( + conn, shared_registry, shutdown_event, pipe_write_fd, + ) except Exception: logger.error(f'Connection error:\n{traceback.format_exc()}') finally: @@ -56,6 +59,7 @@ def _handle_connection_inner( conn: socket.socket, shared_registry: SharedRegistry, shutdown_event: threading.Event, + pipe_write_fd: int | None = None, ) -> None: """Inner connection handler (may raise).""" # --- Handshake --- @@ -87,6 +91,7 @@ def _handle_connection_inner( logger.info(f"Received control signal '{function_name}'") _handle_control_signal( conn, function_name, input_fd, output_fd, shared_registry, + pipe_write_fd, ) return @@ -104,6 +109,7 @@ def _handle_control_signal( input_fd: int, output_fd: int, shared_registry: SharedRegistry, + pipe_write_fd: int | None = None, ) -> None: """Handle a @@-prefixed control signal (one-shot request-response).""" try: @@ -126,7 +132,7 @@ def _handle_control_signal( # Dispatch result = dispatch_control_signal( - signal_name, request_data, shared_registry, + signal_name, request_data, shared_registry, pipe_write_fd, ) if result.ok: diff --git a/singlestoredb/functions/ext/collocated/control.py 
b/singlestoredb/functions/ext/collocated/control.py index 128414217..b2d5c59da 100644 --- a/singlestoredb/functions/ext/collocated/control.py +++ b/singlestoredb/functions/ext/collocated/control.py @@ -29,6 +29,7 @@ def dispatch_control_signal( signal_name: str, request_data: bytes, shared_registry: SharedRegistry, + pipe_write_fd: int | None = None, ) -> ControlResult: """Dispatch a control signal to the appropriate handler.""" try: @@ -37,7 +38,9 @@ def dispatch_control_signal( elif signal_name == '@@functions': return _handle_functions(shared_registry) elif signal_name == '@@register': - return _handle_register(request_data, shared_registry) + return _handle_register( + request_data, shared_registry, pipe_write_fd, + ) else: return ControlResult( ok=False, @@ -62,8 +65,14 @@ def _handle_functions(shared_registry: SharedRegistry) -> ControlResult: def _handle_register( request_data: bytes, shared_registry: SharedRegistry, + pipe_write_fd: int | None = None, ) -> ControlResult: - """Handle @@register: register a new function dynamically.""" + """Handle @@register: register a new function dynamically. + + If ``pipe_write_fd`` is not None (process mode), the registration + payload is written to the pipe so the main process can update its + own registry and re-fork all workers. 
+ """ if not request_data: return ControlResult(ok=False, data='Missing registration payload') @@ -111,5 +120,15 @@ def _handle_register( except Exception as e: return ControlResult(ok=False, data=str(e)) + # Notify main process so it can re-fork workers with updated state + if pipe_write_fd is not None: + from .server import _write_pipe_message + payload = json.dumps({ + 'signature_json': signature, + 'code': func_body, + 'replace': replace, + }).encode() + _write_pipe_message(pipe_write_fd, payload) + logger.info(f"@@register: added function '{function_name}'") return ControlResult(ok=True, data='{"status":"ok"}') diff --git a/singlestoredb/functions/ext/collocated/server.py b/singlestoredb/functions/ext/collocated/server.py index ec5e31253..fdb0b75d4 100644 --- a/singlestoredb/functions/ext/collocated/server.py +++ b/singlestoredb/functions/ext/collocated/server.py @@ -6,12 +6,14 @@ counter caching for thread-safe live reload. """ import importlib +import json import logging import multiprocessing import os import select import signal import socket +import struct import sys import threading import traceback @@ -28,6 +30,40 @@ logger = logging.getLogger('collocated.server') +def _read_pipe_message(fd: int) -> Optional[bytes]: + """Read a length-prefixed message from a pipe fd. + + Wire format: [u32 LE length][payload]. + Returns None on EOF or short read. + """ + try: + len_buf = b'' + while len(len_buf) < 4: + chunk = os.read(fd, 4 - len(len_buf)) + if not chunk: + return None + len_buf += chunk + length = struct.unpack(' None: + """Write a length-prefixed message to a pipe fd. + + Wire format: [u32 LE length][payload]. + """ + header = struct.pack(' None: - """Pre-fork worker pool for true CPU parallelism.""" - ctx = multiprocessing.get_context('fork') - workers: Dict[int, multiprocessing.process.BaseProcess] = {} + """Pre-fork worker pool for true CPU parallelism. 
- def _spawn_worker( - worker_id: int, - ) -> multiprocessing.process.BaseProcess: + Each worker gets a pipe back to the main process. When a worker + receives @@register, it writes the registration payload to its + pipe. The main process reads it, updates its own registry, then + kills and re-forks all workers so every worker has the updated + registry state. + """ + ctx = multiprocessing.get_context('fork') + # workers[wid] = (process, pipe_read_fd) + workers: Dict[ + int, + Tuple[multiprocessing.process.BaseProcess, int], + ] = {} + + def _spawn_worker(worker_id: int) -> Tuple[ + multiprocessing.process.BaseProcess, int, + ]: + pipe_r, pipe_w = os.pipe() p = ctx.Process( target=self._worker_process_main, - args=(server_sock, worker_id), + args=(server_sock, worker_id, pipe_w), daemon=True, ) p.start() + # Close the write end in the parent — only the child writes + os.close(pipe_w) logger.info( f'Started worker {worker_id} (pid={p.pid})', ) - return p + return p, pipe_r + + def _kill_all_workers() -> None: + """SIGTERM all workers, wait, then SIGKILL stragglers.""" + for wid, (proc, pipe_r) in workers.items(): + if proc.is_alive(): + assert proc.pid is not None + os.kill(proc.pid, signal.SIGTERM) + for wid, (proc, pipe_r) in workers.items(): + proc.join(timeout=5.0) + if proc.is_alive(): + logger.warning( + f'Worker {wid} (pid={proc.pid}) ' + f'did not exit, terminating...', + ) + proc.terminate() + proc.join(timeout=2.0) + # Close all pipe read fds + for wid, (proc, pipe_r) in workers.items(): + try: + os.close(pipe_r) + except OSError: + pass + + def _respawn_all_workers() -> None: + """Kill all workers and re-fork them with fresh state.""" + _kill_all_workers() + workers.clear() + for i in range(n_workers): + workers[i] = _spawn_worker(i) # Fork initial workers logger.info( @@ -233,11 +312,59 @@ def _spawn_worker( for i in range(n_workers): workers[i] = _spawn_worker(i) - # Monitor loop: restart dead workers + # Monitor loop using poll() over pipe read fds 
try: while not self.shutdown_event.is_set(): - self.shutdown_event.wait(timeout=0.5) - for wid, proc in list(workers.items()): + poller = select.poll() + fd_to_wid: Dict[int, int] = {} + for wid, (proc, pipe_r) in workers.items(): + poller.register( + pipe_r, select.POLLIN | select.POLLHUP, + ) + fd_to_wid[pipe_r] = wid + + events = poller.poll(500) # 500ms timeout + + registration_received = False + for fd, event in events: + if fd not in fd_to_wid: + continue + wid = fd_to_wid[fd] + + if event & select.POLLIN: + msg = _read_pipe_message(fd) + if msg is not None: + # Apply registration to main's registry + try: + body = json.loads(msg) + self.shared_registry.create_function( + body['signature_json'], + body['code'], + body['replace'], + ) + logger.info( + 'Main process: applied ' + '@@register from worker ' + f'{wid}, will re-fork all ' + 'workers', + ) + registration_received = True + except Exception: + logger.error( + 'Main process: failed to ' + 'apply @@register:\n' + f'{traceback.format_exc()}', + ) + elif event & select.POLLHUP: + # Worker died — will be respawned below + pass + + if registration_received: + _respawn_all_workers() + continue + + # Check for dead workers and respawn individually + for wid, (proc, pipe_r) in list(workers.items()): if not proc.is_alive(): exitcode = proc.exitcode if not self.shutdown_event.is_set(): @@ -246,39 +373,29 @@ def _spawn_worker( f'exited with code {exitcode}, ' f'restarting...', ) + try: + os.close(pipe_r) + except OSError: + pass workers[wid] = _spawn_worker(wid) finally: logger.info('Shutting down worker processes...') - # Signal all workers to stop - for wid, proc in workers.items(): - if proc.is_alive(): - assert proc.pid is not None - os.kill(proc.pid, signal.SIGTERM) - - # Wait for graceful exit - for wid, proc in workers.items(): - proc.join(timeout=5.0) - if proc.is_alive(): - logger.warning( - f'Worker {wid} (pid={proc.pid}) ' - f'did not exit, terminating...', - ) - proc.terminate() - 
proc.join(timeout=2.0) + _kill_all_workers() def _worker_process_main( self, server_sock: socket.socket, worker_id: int, + pipe_w: int, ) -> None: - """Entry point for each forked worker process.""" - try: - # Each worker gets its own registry and shutdown event - local_shared = SharedRegistry() - local_registry = FunctionRegistry() - local_registry.initialize() - local_shared.set_base_registry(local_registry) + """Entry point for each forked worker process. + Uses ``self.shared_registry`` inherited via fork (contains the + main process's current state). ``pipe_w`` is used to notify the + main process when @@register is handled so it can re-fork all + workers. + """ + try: local_shutdown = threading.Event() def _worker_signal_handler( @@ -297,9 +414,10 @@ def _worker_signal_handler( # non-blocking accept and the parent doesn't call accept. server_sock.setblocking(False) + registry = self.shared_registry.get_thread_local_registry() logger.info( f'Worker {worker_id} (pid={os.getpid()}) ready, ' - f'{len(local_registry.functions)} function(s)', + f'{len(registry.functions)} function(s)', ) # Accept loop @@ -322,8 +440,9 @@ def _worker_signal_handler( handle_connection( conn, - local_shared, + self.shared_registry, local_shutdown, + pipe_write_fd=pipe_w, ) except Exception: logger.error( @@ -331,6 +450,11 @@ def _worker_signal_handler( f'{traceback.format_exc()}', ) raise + finally: + try: + os.close(pipe_w) + except OSError: + pass def _initialize_registry(self) -> FunctionRegistry: """Import the extension module and discover @udf functions.""" From 1b90ab6797cc3fb8065d3c0eb38f7a3575269974 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 1 Apr 2026 10:26:18 -0500 Subject: [PATCH 12/19] Fix broken pipe in collocated UDF server under concurrent load Add poll()-based timeout to C recv_exact to avoid the interaction between Python's settimeout() (which sets O_NONBLOCK on the fd) and direct fd-level recv() in the C code. 
When the fd was non-blocking, recv() returned EAGAIN immediately when no data was available, which the C code treated as an error, closing the connection and causing EPIPE on the client side. - accel.c: Add optional timeout_ms parameter to recv_exact that uses poll(POLLIN) before each recv() call, raising TimeoutError on timeout. Also add mmap_read and mmap_write C helpers for fd-level I/O. - connection.py: Only call settimeout() for the Python fallback path; keep fd blocking for C accel path. Pass 100ms timeout to C recv_exact. Catch TimeoutError instead of socket.timeout. Replace select() loop with timeout-based recv. Add C accel paths for mmap read/write. Add optional per-request profiling via SINGLESTOREDB_UDF_PROFILE=1. - registry.py: Consolidate accel imports (mmap_read, mmap_write, recv_exact) under single _has_accel flag. - wasm.py: Update to use renamed _has_accel flag. Co-Authored-By: Claude Opus 4.6 --- accel.c | 155 +++++++++++++++ .../functions/ext/collocated/connection.py | 183 ++++++++++++++---- .../functions/ext/collocated/registry.py | 18 +- .../functions/ext/collocated/wasm.py | 4 +- 4 files changed, 315 insertions(+), 45 deletions(-) diff --git a/accel.c b/accel.c index 29d482bfc..562269e8d 100644 --- a/accel.c +++ b/accel.c @@ -1,7 +1,12 @@ #include +#include #include #include +#include +#include +#include +#include #include #ifndef Py_LIMITED_API @@ -5293,6 +5298,153 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k goto exit; } +/* + * mmap_read(fd, length) -> bytes + * + * Maps the given fd with MAP_SHARED|PROT_READ for `length` bytes, + * copies into a Python bytes object, and unmaps in a single C call. + * Eliminates Python mmap object creation/destruction overhead. 
+ */ +static PyObject *accel_mmap_read(PyObject *self, PyObject *args) { + int fd; + Py_ssize_t length; + + if (!PyArg_ParseTuple(args, "in", &fd, &length)) + return NULL; + + if (length <= 0) { + return PyBytes_FromStringAndSize(NULL, 0); + } + + void *addr = mmap(NULL, (size_t)length, PROT_READ, MAP_SHARED, fd, 0); + if (addr == MAP_FAILED) { + PyErr_SetFromErrno(PyExc_OSError); + return NULL; + } + + PyObject *result = PyBytes_FromStringAndSize((const char *)addr, length); + munmap(addr, (size_t)length); + return result; +} + +/* + * mmap_write(fd, data, min_size) -> None + * + * Writes `data` to the file descriptor, combining ftruncate + lseek + write + * into a single C call. If min_size > 0, ftruncate is called with + * max(min_size, len(data)); if min_size == 0, ftruncate is skipped + * (caller manages file size). + */ +static PyObject *accel_mmap_write(PyObject *self, PyObject *args) { + int fd; + const char *data; + Py_ssize_t data_len; + Py_ssize_t min_size; + + if (!PyArg_ParseTuple(args, "iy#n", &fd, &data, &data_len, &min_size)) + return NULL; + + if (min_size > 0) { + Py_ssize_t trunc_size = data_len > min_size ? data_len : min_size; + if (ftruncate(fd, (off_t)trunc_size) < 0) { + PyErr_SetFromErrno(PyExc_OSError); + return NULL; + } + } + + if (lseek(fd, 0, SEEK_SET) < 0) { + PyErr_SetFromErrno(PyExc_OSError); + return NULL; + } + + const char *p = data; + Py_ssize_t remaining = data_len; + while (remaining > 0) { + ssize_t written = write(fd, p, (size_t)remaining); + if (written < 0) { + PyErr_SetFromErrno(PyExc_OSError); + return NULL; + } + p += written; + remaining -= written; + } + + Py_RETURN_NONE; +} + +/* + * recv_exact(fd, n, timeout_ms=-1) -> bytes or None + * + * Receives exactly `n` bytes from a socket fd using blocking recv. + * Returns None on EOF (peer closed). Operates on raw fd to avoid + * Python socket object overhead. Releases the GIL during recv. 
+ * + * When timeout_ms >= 0, uses poll() before each recv() to wait for + * data with a timeout. Raises TimeoutError on timeout. This allows + * the fd to remain in blocking mode while still supporting timeouts, + * avoiding the interaction between Python's settimeout() (which sets + * O_NONBLOCK) and direct fd-level recv(). + */ +static PyObject *accel_recv_exact(PyObject *self, PyObject *args) { + int fd, timeout_ms = -1; + Py_ssize_t n; + + if (!PyArg_ParseTuple(args, "in|i", &fd, &n, &timeout_ms)) + return NULL; + + if (n <= 0) { + return PyBytes_FromStringAndSize(NULL, 0); + } + + char *buf = (char *)malloc((size_t)n); + if (!buf) { + PyErr_NoMemory(); + return NULL; + } + + Py_ssize_t pos = 0; + while (pos < n) { + if (timeout_ms >= 0) { + struct pollfd pfd = {fd, POLLIN, 0}; + int poll_rc; + Py_BEGIN_ALLOW_THREADS + poll_rc = poll(&pfd, 1, timeout_ms); + Py_END_ALLOW_THREADS + if (poll_rc == 0) { + free(buf); + PyErr_SetString(PyExc_TimeoutError, "recv_exact timed out"); + return NULL; + } + if (poll_rc < 0) { + free(buf); + PyErr_SetFromErrno(PyExc_OSError); + return NULL; + } + } + + ssize_t received; + Py_BEGIN_ALLOW_THREADS + received = recv(fd, buf + pos, (size_t)(n - pos), 0); + Py_END_ALLOW_THREADS + + if (received < 0) { + free(buf); + PyErr_SetFromErrno(PyExc_OSError); + return NULL; + } + if (received == 0) { + /* EOF */ + free(buf); + Py_RETURN_NONE; + } + pos += received; + } + + PyObject *result = PyBytes_FromStringAndSize(buf, n); + free(buf); + return result; +} + static PyMethodDef PyMySQLAccelMethods[] = { {"read_rowdata_packet", (PyCFunction)read_rowdata_packet, METH_VARARGS | METH_KEYWORDS, "PyMySQL row data packet reader"}, {"dump_rowdat_1", (PyCFunction)dump_rowdat_1, METH_VARARGS | METH_KEYWORDS, "ROWDAT_1 formatter for external functions"}, @@ -5300,6 +5452,9 @@ static PyMethodDef PyMySQLAccelMethods[] = { {"dump_rowdat_1_numpy", (PyCFunction)dump_rowdat_1_numpy, METH_VARARGS | METH_KEYWORDS, "ROWDAT_1 formatter for external functions 
which takes numpy.arrays"}, {"load_rowdat_1_numpy", (PyCFunction)load_rowdat_1_numpy, METH_VARARGS | METH_KEYWORDS, "ROWDAT_1 parser for external functions which creates numpy.arrays"}, {"call_function_accel", (PyCFunction)call_function_accel, METH_VARARGS | METH_KEYWORDS, "Combined load/call/dump for UDF function calls"}, + {"mmap_read", (PyCFunction)accel_mmap_read, METH_VARARGS, "mmap read: maps fd, copies data, unmaps"}, + {"mmap_write", (PyCFunction)accel_mmap_write, METH_VARARGS, "mmap write: ftruncate+lseek+write in one call"}, + {"recv_exact", (PyCFunction)accel_recv_exact, METH_VARARGS, "Receive exactly N bytes from a socket fd"}, {NULL, NULL, 0, NULL} }; diff --git a/singlestoredb/functions/ext/collocated/connection.py b/singlestoredb/functions/ext/collocated/connection.py index 931ec1780..7c8149d1b 100644 --- a/singlestoredb/functions/ext/collocated/connection.py +++ b/singlestoredb/functions/ext/collocated/connection.py @@ -10,14 +10,18 @@ import logging import mmap import os -import select import socket import struct import threading +import time import traceback from typing import TYPE_CHECKING from .control import dispatch_control_signal +from .registry import _has_accel +from .registry import _mmap_read +from .registry import _mmap_write +from .registry import _recv_exact as _c_recv_exact from .registry import call_function if TYPE_CHECKING: @@ -34,6 +38,12 @@ # Minimum output mmap size to avoid repeated ftruncate _MIN_OUTPUT_SIZE = 128 * 1024 +# Pre-pack the status OK header prefix to avoid per-request struct.pack +_STATUS_OK_PREFIX = struct.pack(' 0: - mem = mmap.mmap( - input_fd, length, mmap.MAP_SHARED, mmap.PROT_READ, - ) - try: - request_data = mem[:length] - finally: - mem.close() + if _has_accel: + request_data = _mmap_read(input_fd, length) + else: + mem = mmap.mmap( + input_fd, length, mmap.MAP_SHARED, mmap.PROT_READ, + ) + try: + request_data = bytes(mem[:length]) + finally: + mem.close() # Dispatch result = dispatch_control_signal( @@ 
-139,9 +152,15 @@ def _handle_control_signal( # Write response to output mmap response_bytes = result.data.encode('utf8') response_size = len(response_bytes) - os.ftruncate(output_fd, max(_MIN_OUTPUT_SIZE, response_size)) - os.lseek(output_fd, 0, os.SEEK_SET) - os.write(output_fd, response_bytes) + if _has_accel: + _mmap_write( + output_fd, response_bytes, + max(_MIN_OUTPUT_SIZE, response_size), + ) + else: + os.ftruncate(output_fd, max(_MIN_OUTPUT_SIZE, response_size)) + os.lseek(output_fd, 0, os.SEEK_SET) + os.write(output_fd, response_bytes) # Send [status=200, size] conn.sendall(struct.pack(' current_output_size: - os.ftruncate(output_fd, needed) - current_output_size = needed - os.lseek(output_fd, 0, os.SEEK_SET) - os.write(output_fd, output_data) + if profile: + t0 = time.monotonic() + if use_accel: + needed = max(_MIN_OUTPUT_SIZE, response_size) + if needed > current_output_size: + _mmap_write(output_fd, output_data, needed) + current_output_size = needed + else: + _mmap_write(output_fd, output_data, 0) + else: + needed = max(_MIN_OUTPUT_SIZE, response_size) + if needed > current_output_size: + os.ftruncate(output_fd, needed) + current_output_size = needed + os.lseek(output_fd, 0, os.SEEK_SET) + os.write(output_fd, output_data) + if profile: + t_mmap_write += time.monotonic() - t0 # Send [status=200, size] - conn.sendall(struct.pack(' 0: + t_total = ( + t_recv + t_mmap_read + t_call + t_mmap_write + t_send + ) / n_requests * 1e6 + logger.info( + f"PROFILE '{function_name}' " + f'n={n_requests} ' + f'recv={t_recv / n_requests * 1e6:.1f}us ' + f'mmap_read={t_mmap_read / n_requests * 1e6:.1f}us ' + f'call={t_call / n_requests * 1e6:.1f}us ' + f'mmap_write={t_mmap_write / n_requests * 1e6:.1f}us ' + f'send={t_send / n_requests * 1e6:.1f}us ' + f'total={t_total:.1f}us', + ) + -def _recv_exact(sock: socket.socket, n: int) -> bytes | None: +def _recv_exact_py(sock: socket.socket, n: int) -> bytes | None: """Receive exactly n bytes, or return None on EOF.""" - buf = 
bytearray() - while len(buf) < n: - chunk = sock.recv(n - len(buf)) - if not chunk: + buf = bytearray(n) + view = memoryview(buf) + pos = 0 + while pos < n: + nbytes = sock.recv_into(view[pos:]) + if nbytes == 0: return None - buf.extend(chunk) + pos += nbytes return bytes(buf) diff --git a/singlestoredb/functions/ext/collocated/registry.py b/singlestoredb/functions/ext/collocated/registry.py index f44c83be8..8b931dc1d 100644 --- a/singlestoredb/functions/ext/collocated/registry.py +++ b/singlestoredb/functions/ext/collocated/registry.py @@ -29,9 +29,19 @@ try: from _singlestoredb_accel import call_function_accel as _call_function_accel - _has_call_accel = True -except Exception: - _has_call_accel = False + from _singlestoredb_accel import mmap_read as _mmap_read + from _singlestoredb_accel import mmap_write as _mmap_write + from _singlestoredb_accel import recv_exact as _recv_exact + _has_accel = True + logging.getLogger(__name__).info('_singlestoredb_accel loaded successfully') +except Exception as e: + _has_accel = False + _mmap_read = None + _mmap_write = None + _recv_exact = None + logging.getLogger(__name__).warning( + '_singlestoredb_accel failed to load: %s', e, + ) class _TracingFormatter(logging.Formatter): @@ -434,7 +444,7 @@ def call_function( return_types = func_info['return_types'] try: - if _has_call_accel: + if _has_accel: return _call_function_accel( colspec=arg_types, returns=return_types, diff --git a/singlestoredb/functions/ext/collocated/wasm.py b/singlestoredb/functions/ext/collocated/wasm.py index fad07b4aa..f6898a3b3 100644 --- a/singlestoredb/functions/ext/collocated/wasm.py +++ b/singlestoredb/functions/ext/collocated/wasm.py @@ -8,7 +8,7 @@ import logging import traceback -from .registry import _has_call_accel +from .registry import _has_accel from .registry import call_function from .registry import describe_functions_json from .registry import FunctionRegistry @@ -24,7 +24,7 @@ class FunctionHandler: def initialize(self) -> None: 
"""Initialize and discover UDF functions from loaded modules.""" - if _has_call_accel: + if _has_accel: logger.info('Using accelerated C call_function_accel loop') else: logger.info('Using pure Python call_function loop') From c732dff705cbf53b22436322bc09b5499ccfe185 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 1 Apr 2026 13:39:10 -0500 Subject: [PATCH 13/19] Guard np.dtype check in normalize_dtype for environments without numpy When numpy is not available (e.g., WASM), the `np` name is undefined. The has_numpy flag was already used elsewhere but this check was missed when the numpy_stub was removed. Co-Authored-By: Claude Opus 4.6 --- singlestoredb/functions/signature.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index cf2b5d017..4ec89a480 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -332,7 +332,7 @@ def normalize_dtype(dtype: Any) -> str: if isinstance(dtype, str): return sql_to_dtype(dtype) - if typing.get_origin(dtype) is np.dtype: + if has_numpy and typing.get_origin(dtype) is np.dtype: dtype = typing.get_args(dtype)[0] # Specific types From 1b62a6c847841b4c7e827b6dca4e067f0b6cff08 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 1 Apr 2026 13:51:48 -0500 Subject: [PATCH 14/19] Guard WASI-incompatible POSIX APIs in accel.c with #ifndef __wasi__ The mmap_read, mmap_write, and recv_exact functions use poll.h, sys/mman.h, and sys/socket.h which are unavailable in WASI. Wrap these includes, function bodies, and PyMethodDef entries with #ifndef __wasi__ guards so the C extension compiles for wasm32-wasip2. The core call_function_accel optimization remains available. 
Co-Authored-By: Claude Opus 4.6 --- accel.c | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/accel.c b/accel.c index 562269e8d..a7246aeef 100644 --- a/accel.c +++ b/accel.c @@ -1,11 +1,15 @@ #include +#ifndef __wasi__ #include +#endif #include #include #include +#ifndef __wasi__ #include #include +#endif #include #include @@ -5298,6 +5302,7 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k goto exit; } +#ifndef __wasi__ /* * mmap_read(fd, length) -> bytes * @@ -5444,6 +5449,7 @@ static PyObject *accel_recv_exact(PyObject *self, PyObject *args) { free(buf); return result; } +#endif /* !__wasi__ */ static PyMethodDef PyMySQLAccelMethods[] = { {"read_rowdata_packet", (PyCFunction)read_rowdata_packet, METH_VARARGS | METH_KEYWORDS, "PyMySQL row data packet reader"}, @@ -5452,9 +5458,11 @@ static PyMethodDef PyMySQLAccelMethods[] = { {"dump_rowdat_1_numpy", (PyCFunction)dump_rowdat_1_numpy, METH_VARARGS | METH_KEYWORDS, "ROWDAT_1 formatter for external functions which takes numpy.arrays"}, {"load_rowdat_1_numpy", (PyCFunction)load_rowdat_1_numpy, METH_VARARGS | METH_KEYWORDS, "ROWDAT_1 parser for external functions which creates numpy.arrays"}, {"call_function_accel", (PyCFunction)call_function_accel, METH_VARARGS | METH_KEYWORDS, "Combined load/call/dump for UDF function calls"}, +#ifndef __wasi__ {"mmap_read", (PyCFunction)accel_mmap_read, METH_VARARGS, "mmap read: maps fd, copies data, unmaps"}, {"mmap_write", (PyCFunction)accel_mmap_write, METH_VARARGS, "mmap write: ftruncate+lseek+write in one call"}, {"recv_exact", (PyCFunction)accel_recv_exact, METH_VARARGS, "Receive exactly N bytes from a socket fd"}, +#endif {NULL, NULL, 0, NULL} }; From c65793a0ccd6ea1cdf8a619493751bd59b76232f Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 1 Apr 2026 14:36:40 -0500 Subject: [PATCH 15/19] Call setup_logging() in FunctionHandler.initialize() for WASM handler Without this, the accel status log messages ("Using accelerated C 
call_function_accel loop" / "Using pure Python call_function loop") are silently dropped because no logging handler is configured in the WASM handler path. setup_logging() was only called from __main__.py (collocated server CLI). Co-Authored-By: Claude Opus 4.6 --- singlestoredb/functions/ext/collocated/wasm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/singlestoredb/functions/ext/collocated/wasm.py b/singlestoredb/functions/ext/collocated/wasm.py index f6898a3b3..ad0625877 100644 --- a/singlestoredb/functions/ext/collocated/wasm.py +++ b/singlestoredb/functions/ext/collocated/wasm.py @@ -12,6 +12,7 @@ from .registry import call_function from .registry import describe_functions_json from .registry import FunctionRegistry +from .registry import setup_logging logger = logging.getLogger('udf_handler') @@ -24,6 +25,7 @@ class FunctionHandler: def initialize(self) -> None: """Initialize and discover UDF functions from loaded modules.""" + setup_logging() if _has_accel: logger.info('Using accelerated C call_function_accel loop') else: From b66653a327622c6fee28400d5472c59bb4620b7a Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Wed, 1 Apr 2026 14:49:53 -0500 Subject: [PATCH 16/19] Add WASI stubs for mmap_read/mmap_write/recv_exact and fix accel logging The _singlestoredb_accel C extension ifdef'd out the mmap and socket functions for __wasi__ builds, but registry.py imports all four symbols (call_function_accel, mmap_read, mmap_write, recv_exact) in a single try block. The missing exports caused the entire import to fail, silently falling back to the pure Python call_function loop. Add #else stubs that raise NotImplementedError if called, so the symbols are importable and call_function_accel works in WASM. Also capture the accel import error and log it in initialize() for future diagnostics. 
Co-Authored-By: Claude Opus 4.6 --- accel.c | 19 +++++++++++++++++-- .../functions/ext/collocated/registry.py | 2 ++ .../functions/ext/collocated/wasm.py | 6 ++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/accel.c b/accel.c index a7246aeef..2062fcba8 100644 --- a/accel.c +++ b/accel.c @@ -5449,6 +5449,23 @@ static PyObject *accel_recv_exact(PyObject *self, PyObject *args) { free(buf); return result; } +#else /* __wasi__ stubs — importable but raise NotImplementedError if called */ + +static PyObject *accel_mmap_read(PyObject *self, PyObject *args) { + PyErr_SetString(PyExc_NotImplementedError, "mmap_read is not available in WASM"); + return NULL; +} + +static PyObject *accel_mmap_write(PyObject *self, PyObject *args) { + PyErr_SetString(PyExc_NotImplementedError, "mmap_write is not available in WASM"); + return NULL; +} + +static PyObject *accel_recv_exact(PyObject *self, PyObject *args) { + PyErr_SetString(PyExc_NotImplementedError, "recv_exact is not available in WASM"); + return NULL; +} + #endif /* !__wasi__ */ static PyMethodDef PyMySQLAccelMethods[] = { @@ -5458,11 +5475,9 @@ static PyMethodDef PyMySQLAccelMethods[] = { {"dump_rowdat_1_numpy", (PyCFunction)dump_rowdat_1_numpy, METH_VARARGS | METH_KEYWORDS, "ROWDAT_1 formatter for external functions which takes numpy.arrays"}, {"load_rowdat_1_numpy", (PyCFunction)load_rowdat_1_numpy, METH_VARARGS | METH_KEYWORDS, "ROWDAT_1 parser for external functions which creates numpy.arrays"}, {"call_function_accel", (PyCFunction)call_function_accel, METH_VARARGS | METH_KEYWORDS, "Combined load/call/dump for UDF function calls"}, -#ifndef __wasi__ {"mmap_read", (PyCFunction)accel_mmap_read, METH_VARARGS, "mmap read: maps fd, copies data, unmaps"}, {"mmap_write", (PyCFunction)accel_mmap_write, METH_VARARGS, "mmap write: ftruncate+lseek+write in one call"}, {"recv_exact", (PyCFunction)accel_recv_exact, METH_VARARGS, "Receive exactly N bytes from a socket fd"}, -#endif {NULL, NULL, 0, NULL} }; diff 
--git a/singlestoredb/functions/ext/collocated/registry.py b/singlestoredb/functions/ext/collocated/registry.py index 8b931dc1d..822921e16 100644 --- a/singlestoredb/functions/ext/collocated/registry.py +++ b/singlestoredb/functions/ext/collocated/registry.py @@ -27,6 +27,7 @@ from singlestoredb.functions.signature import get_signature from singlestoredb.mysql.constants import FIELD_TYPE as ft +_accel_error: Optional[str] = None try: from _singlestoredb_accel import call_function_accel as _call_function_accel from _singlestoredb_accel import mmap_read as _mmap_read @@ -36,6 +37,7 @@ logging.getLogger(__name__).info('_singlestoredb_accel loaded successfully') except Exception as e: _has_accel = False + _accel_error = str(e) _mmap_read = None _mmap_write = None _recv_exact = None diff --git a/singlestoredb/functions/ext/collocated/wasm.py b/singlestoredb/functions/ext/collocated/wasm.py index ad0625877..7e8dda6d8 100644 --- a/singlestoredb/functions/ext/collocated/wasm.py +++ b/singlestoredb/functions/ext/collocated/wasm.py @@ -8,6 +8,7 @@ import logging import traceback +from .registry import _accel_error from .registry import _has_accel from .registry import call_function from .registry import describe_functions_json @@ -30,6 +31,11 @@ def initialize(self) -> None: logger.info('Using accelerated C call_function_accel loop') else: logger.info('Using pure Python call_function loop') + if _accel_error: + logger.warning( + '_singlestoredb_accel failed to load: %s', + _accel_error, + ) _registry.initialize() def call_function(self, name: str, input_data: bytes) -> bytes: From 34e24527922d6924699d1355b9877bd72e8412bc Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 2 Apr 2026 10:37:24 -0500 Subject: [PATCH 17/19] Address PR #121 review comments: memory safety, correctness, hardening accel.c: - Replace empty TODO type stubs with NotImplementedError raises - Add CHECK_REMAINING macro for bounds checking on buffer reads - Replace unaligned pointer-cast reads with 
memcpy for WASM/ARM safety - Fix double-decref in output error paths (set to NULL before goto) - Fix Py_None reference leak by removing pre-switch INCREF - Fix MYSQL_TYPE_NULL consuming an extra byte from next column - Add PyErr_Format in default switch cases - Add PyErr_Occurred() checks after PyLong/PyFloat conversions Python: - Align list/tuple multi-return handling in registry.py with C path - Add _write_all_fd helper for partial os.write() handling - Harden handshake recvmsg: name length bound, ancdata validation, MSG_CTRUNC check, FD cleanup on error - Wrap get_context('fork') with platform safety error - Narrow events.py exception catch to (ImportError, OSError) - Fix _iquery DataFrame check ordering (check before list()) - Expand setblocking(False) warning comment - Update WIT and wasm.py docstrings for code parameter Co-Authored-By: Claude Opus 4.6 --- accel.c | 239 +++++++++++++----- singlestoredb/connection.py | 9 +- .../functions/ext/collocated/connection.py | 69 ++++- .../functions/ext/collocated/registry.py | 2 +- .../functions/ext/collocated/server.py | 24 +- .../functions/ext/collocated/wasm.py | 6 +- singlestoredb/utils/events.py | 2 +- wit/udf.wit | 4 +- 8 files changed, 269 insertions(+), 86 deletions(-) diff --git a/accel.c b/accel.c index 2062fcba8..394578000 100644 --- a/accel.c +++ b/accel.c @@ -4847,61 +4847,73 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k out = new_out; \ } + // Bounds-check macro for input buffer reads +#define CHECK_REMAINING(n) do { \ + if ((size_t)(end - data) < (size_t)(n)) { \ + PyErr_SetString(PyExc_ValueError, "truncated rowdat_1 input"); \ + goto error; \ + } \ +} while(0) + // Main loop: parse input rows, call function, serialize output while (end > data) { py_row = PyTuple_New(colspec_l); if (!py_row) goto error; // Read row ID - row_id = *(int64_t*)data; data += 8; + CHECK_REMAINING(8); + memcpy(&row_id, data, 8); data += 8; // Parse input columns for (i = 0; i < colspec_l; 
i++) { + CHECK_REMAINING(1); is_null = data[0] == '\x01'; data += 1; - if (is_null) Py_INCREF(Py_None); switch (ctypes[i]) { case MYSQL_TYPE_NULL: - data += 1; - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); break; case MYSQL_TYPE_TINY: - i8 = *(int8_t*)data; data += 1; + CHECK_REMAINING(1); + memcpy(&i8, data, 1); data += 1; if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); } else { CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromLong((long)i8))); } break; case -MYSQL_TYPE_TINY: - u8 = *(uint8_t*)data; data += 1; + CHECK_REMAINING(1); + memcpy(&u8, data, 1); data += 1; if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); } else { CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLong((unsigned long)u8))); } break; case MYSQL_TYPE_SHORT: - i16 = *(int16_t*)data; data += 2; + CHECK_REMAINING(2); + memcpy(&i16, data, 2); data += 2; if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); } else { CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromLong((long)i16))); } break; case -MYSQL_TYPE_SHORT: - u16 = *(uint16_t*)data; data += 2; + CHECK_REMAINING(2); + memcpy(&u16, data, 2); data += 2; if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); } else { CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLong((unsigned long)u16))); } @@ -4909,10 +4921,11 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k case MYSQL_TYPE_LONG: case MYSQL_TYPE_INT24: - i32 = *(int32_t*)data; data += 4; + CHECK_REMAINING(4); + memcpy(&i32, data, 4); data += 4; if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, 
Py_None)); } else { CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromLong((long)i32))); } @@ -4920,50 +4933,55 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k case -MYSQL_TYPE_LONG: case -MYSQL_TYPE_INT24: - u32 = *(uint32_t*)data; data += 4; + CHECK_REMAINING(4); + memcpy(&u32, data, 4); data += 4; if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); } else { CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLong((unsigned long)u32))); } break; case MYSQL_TYPE_LONGLONG: - i64 = *(int64_t*)data; data += 8; + CHECK_REMAINING(8); + memcpy(&i64, data, 8); data += 8; if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); } else { CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromLongLong((long long)i64))); } break; case -MYSQL_TYPE_LONGLONG: - u64 = *(uint64_t*)data; data += 8; + CHECK_REMAINING(8); + memcpy(&u64, data, 8); data += 8; if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); } else { CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLongLong((unsigned long long)u64))); } break; case MYSQL_TYPE_FLOAT: - flt = *(float*)data; data += 4; + CHECK_REMAINING(4); + memcpy(&flt, data, 4); data += 4; if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); } else { CHECKRC(PyTuple_SetItem(py_row, i, PyFloat_FromDouble((double)flt))); } break; case MYSQL_TYPE_DOUBLE: - dbl = *(double*)data; data += 8; + CHECK_REMAINING(8); + memcpy(&dbl, data, 8); data += 8; if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); } else { CHECKRC(PyTuple_SetItem(py_row, i, PyFloat_FromDouble((double)dbl))); } @@ -4971,31 +4989,37 @@ static PyObject 
*call_function_accel(PyObject *self, PyObject *args, PyObject *k case MYSQL_TYPE_DECIMAL: case MYSQL_TYPE_NEWDECIMAL: - // TODO - break; + PyErr_SetString(PyExc_NotImplementedError, + "DECIMAL type not yet supported in call_function_accel"); + goto error; case MYSQL_TYPE_DATE: case MYSQL_TYPE_NEWDATE: - // TODO - break; + PyErr_SetString(PyExc_NotImplementedError, + "DATE type not yet supported in call_function_accel"); + goto error; case MYSQL_TYPE_TIME: - // TODO - break; + PyErr_SetString(PyExc_NotImplementedError, + "TIME type not yet supported in call_function_accel"); + goto error; case MYSQL_TYPE_DATETIME: - // TODO - break; + PyErr_SetString(PyExc_NotImplementedError, + "DATETIME type not yet supported in call_function_accel"); + goto error; case MYSQL_TYPE_TIMESTAMP: - // TODO - break; + PyErr_SetString(PyExc_NotImplementedError, + "TIMESTAMP type not yet supported in call_function_accel"); + goto error; case MYSQL_TYPE_YEAR: - u16 = *(uint16_t*)data; data += 2; + CHECK_REMAINING(2); + memcpy(&u16, data, 2); data += 2; if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); } else { CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLong((unsigned long)u16))); } @@ -5012,11 +5036,13 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k case MYSQL_TYPE_MEDIUM_BLOB: case MYSQL_TYPE_LONG_BLOB: case MYSQL_TYPE_BLOB: - i64 = *(int64_t*)data; data += 8; + CHECK_REMAINING(8); + memcpy(&i64, data, 8); data += 8; if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); } else { + CHECK_REMAINING((size_t)i64); py_str = PyUnicode_FromStringAndSize(data, (Py_ssize_t)i64); data += i64; if (!py_str) goto error; @@ -5036,11 +5062,13 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k case -MYSQL_TYPE_MEDIUM_BLOB: case -MYSQL_TYPE_LONG_BLOB: case 
-MYSQL_TYPE_BLOB: - i64 = *(int64_t*)data; data += 8; + CHECK_REMAINING(8); + memcpy(&i64, data, 8); data += 8; if (is_null) { - CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); } else { + CHECK_REMAINING((size_t)i64); py_blob = PyBytes_FromStringAndSize(data, (Py_ssize_t)i64); data += i64; if (!py_blob) goto error; @@ -5050,10 +5078,14 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k break; default: + PyErr_Format(PyExc_TypeError, + "unsupported input column type: %d", ctypes[i]); goto error; } } +#undef CHECK_REMAINING + // Call the user function py_result = PyObject_Call(py_func, py_row, NULL); Py_DECREF(py_row); @@ -5086,33 +5118,56 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k switch (rtypes[i]) { case MYSQL_TYPE_BIT: - // TODO - break; + Py_DECREF(py_result_item); + py_result_item = NULL; + PyErr_SetString(PyExc_NotImplementedError, + "BIT type not yet supported in call_function_accel"); + goto error; case MYSQL_TYPE_TINY: CHECKMEM_CFA(1); - i8 = (is_null) ? 0 : (int8_t)PyLong_AsLong(py_result_item); + if (is_null) { + i8 = 0; + } else { + i8 = (int8_t)PyLong_AsLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } memcpy(out+out_idx, &i8, 1); out_idx += 1; break; case -MYSQL_TYPE_TINY: CHECKMEM_CFA(1); - u8 = (is_null) ? 0 : (uint8_t)PyLong_AsUnsignedLong(py_result_item); + if (is_null) { + u8 = 0; + } else { + u8 = (uint8_t)PyLong_AsUnsignedLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } memcpy(out+out_idx, &u8, 1); out_idx += 1; break; case MYSQL_TYPE_SHORT: CHECKMEM_CFA(2); - i16 = (is_null) ? 
0 : (int16_t)PyLong_AsLong(py_result_item); + if (is_null) { + i16 = 0; + } else { + i16 = (int16_t)PyLong_AsLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } memcpy(out+out_idx, &i16, 2); out_idx += 2; break; case -MYSQL_TYPE_SHORT: CHECKMEM_CFA(2); - u16 = (is_null) ? 0 : (uint16_t)PyLong_AsUnsignedLong(py_result_item); + if (is_null) { + u16 = 0; + } else { + u16 = (uint16_t)PyLong_AsUnsignedLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } memcpy(out+out_idx, &u16, 2); out_idx += 2; break; @@ -5120,7 +5175,12 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k case MYSQL_TYPE_LONG: case MYSQL_TYPE_INT24: CHECKMEM_CFA(4); - i32 = (is_null) ? 0 : (int32_t)PyLong_AsLong(py_result_item); + if (is_null) { + i32 = 0; + } else { + i32 = (int32_t)PyLong_AsLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } memcpy(out+out_idx, &i32, 4); out_idx += 4; break; @@ -5128,63 +5188,109 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k case -MYSQL_TYPE_LONG: case -MYSQL_TYPE_INT24: CHECKMEM_CFA(4); - u32 = (is_null) ? 0 : (uint32_t)PyLong_AsUnsignedLong(py_result_item); + if (is_null) { + u32 = 0; + } else { + u32 = (uint32_t)PyLong_AsUnsignedLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } memcpy(out+out_idx, &u32, 4); out_idx += 4; break; case MYSQL_TYPE_LONGLONG: CHECKMEM_CFA(8); - i64 = (is_null) ? 0 : (int64_t)PyLong_AsLongLong(py_result_item); + if (is_null) { + i64 = 0; + } else { + i64 = (int64_t)PyLong_AsLongLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } memcpy(out+out_idx, &i64, 8); out_idx += 8; break; case -MYSQL_TYPE_LONGLONG: CHECKMEM_CFA(8); - u64 = (is_null) ? 
0 : (uint64_t)PyLong_AsUnsignedLongLong(py_result_item); + if (is_null) { + u64 = 0; + } else { + u64 = (uint64_t)PyLong_AsUnsignedLongLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } memcpy(out+out_idx, &u64, 8); out_idx += 8; break; case MYSQL_TYPE_FLOAT: CHECKMEM_CFA(4); - flt = (is_null) ? 0 : (float)PyFloat_AsDouble(py_result_item); + if (is_null) { + flt = 0; + } else { + flt = (float)PyFloat_AsDouble(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } memcpy(out+out_idx, &flt, 4); out_idx += 4; break; case MYSQL_TYPE_DOUBLE: CHECKMEM_CFA(8); - dbl = (is_null) ? 0 : (double)PyFloat_AsDouble(py_result_item); + if (is_null) { + dbl = 0; + } else { + dbl = (double)PyFloat_AsDouble(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } memcpy(out+out_idx, &dbl, 8); out_idx += 8; break; case MYSQL_TYPE_DECIMAL: - // TODO - break; + case MYSQL_TYPE_NEWDECIMAL: + Py_DECREF(py_result_item); + py_result_item = NULL; + PyErr_SetString(PyExc_NotImplementedError, + "DECIMAL type not yet supported in call_function_accel"); + goto error; case MYSQL_TYPE_DATE: case MYSQL_TYPE_NEWDATE: - // TODO - break; + Py_DECREF(py_result_item); + py_result_item = NULL; + PyErr_SetString(PyExc_NotImplementedError, + "DATE type not yet supported in call_function_accel"); + goto error; case MYSQL_TYPE_TIME: - // TODO - break; + Py_DECREF(py_result_item); + py_result_item = NULL; + PyErr_SetString(PyExc_NotImplementedError, + "TIME type not yet supported in call_function_accel"); + goto error; case MYSQL_TYPE_DATETIME: - // TODO - break; + Py_DECREF(py_result_item); + py_result_item = NULL; + PyErr_SetString(PyExc_NotImplementedError, + "DATETIME type not yet supported in call_function_accel"); + goto error; case MYSQL_TYPE_TIMESTAMP: - // TODO - break; + Py_DECREF(py_result_item); + py_result_item = NULL; + 
PyErr_SetString(PyExc_NotImplementedError, + "TIMESTAMP type not yet supported in call_function_accel"); + goto error; case MYSQL_TYPE_YEAR: CHECKMEM_CFA(2); - i16 = (is_null) ? 0 : (int16_t)PyLong_AsLong(py_result_item); + if (is_null) { + i16 = 0; + } else { + i16 = (int16_t)PyLong_AsLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } memcpy(out+out_idx, &i16, 2); out_idx += 2; break; @@ -5209,6 +5315,7 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k py_bytes = PyUnicode_AsEncodedString(py_result_item, "utf-8", "strict"); if (!py_bytes) { Py_DECREF(py_result_item); + py_result_item = NULL; goto error; } @@ -5216,7 +5323,9 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k Py_ssize_t str_l = 0; if (PyBytes_AsStringAndSize(py_bytes, &str, &str_l) < 0) { Py_DECREF(py_bytes); + py_bytes = NULL; Py_DECREF(py_result_item); + py_result_item = NULL; goto error; } @@ -5252,6 +5361,7 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k Py_ssize_t str_l = 0; if (PyBytes_AsStringAndSize(py_result_item, &str, &str_l) < 0) { Py_DECREF(py_result_item); + py_result_item = NULL; goto error; } @@ -5266,6 +5376,9 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k default: Py_DECREF(py_result_item); + py_result_item = NULL; + PyErr_Format(PyExc_TypeError, + "unsupported output column type: %d", rtypes[i]); goto error; } diff --git a/singlestoredb/connection.py b/singlestoredb/connection.py index 01d8c2463..01245dd5f 100644 --- a/singlestoredb/connection.py +++ b/singlestoredb/connection.py @@ -1166,12 +1166,13 @@ def _iquery( cur.execute(oper, params) if not re.match(r'^\s*(select|show|call|echo)\s+', oper, flags=re.I): return [] - out = list(cur.fetchall()) + raw = cur.fetchall() + if hasattr(raw, 'to_dict') and callable(raw.to_dict): + return raw.to_dict(orient='records') + out = 
list(raw) if not out: return [] - if hasattr(out, 'to_dict') and callable(getattr(out, 'to_dict')): - out = out.to_dict(orient='records') - elif isinstance(out[0], (tuple, list)): + if isinstance(out[0], (tuple, list)): if cur.description: names = [x[0] for x in cur.description] if fix_names: diff --git a/singlestoredb/functions/ext/collocated/connection.py b/singlestoredb/functions/ext/collocated/connection.py index 7c8149d1b..c69bd1c13 100644 --- a/singlestoredb/functions/ext/collocated/connection.py +++ b/singlestoredb/functions/ext/collocated/connection.py @@ -41,6 +41,9 @@ # Pre-pack the status OK header prefix to avoid per-request struct.pack _STATUS_OK_PREFIX = struct.pack(' _MAX_FUNCTION_NAME_LEN: + logger.warning(f'Function name too long: {namelen}') + return + # Receive function name + 2 FDs via SCM_RIGHTS fd_model = array.array('i', [0, 0]) msg, ancdata, flags, addr = conn.recvmsg( namelen, socket.CMSG_LEN(2 * fd_model.itemsize), ) - if len(ancdata) != 1: - logger.warning(f'Expected 1 ancdata, got {len(ancdata)}') - return - function_name = msg.decode('utf8') - input_fd, output_fd = struct.unpack(' bytes | None: return None pos += nbytes return bytes(buf) + + +def _write_all_fd(fd: int, data: bytes) -> None: + """Write all bytes to a file descriptor, handling partial writes.""" + view = memoryview(data) + written = 0 + while written < len(data): + try: + n = os.write(fd, view[written:]) + except InterruptedError: + continue + if n == 0: + raise RuntimeError('short write to output fd') + written += n diff --git a/singlestoredb/functions/ext/collocated/registry.py b/singlestoredb/functions/ext/collocated/registry.py index 822921e16..3dadcbf29 100644 --- a/singlestoredb/functions/ext/collocated/registry.py +++ b/singlestoredb/functions/ext/collocated/registry.py @@ -458,7 +458,7 @@ def call_function( results = [] for row in rows: result = func(*row) - if not isinstance(result, tuple): + if not isinstance(result, (tuple, list)): result = [result] 
results.append(list(result)) return bytes(_dump_rowdat_1(return_types, row_ids, results)) diff --git a/singlestoredb/functions/ext/collocated/server.py b/singlestoredb/functions/ext/collocated/server.py index fdb0b75d4..60e31d174 100644 --- a/singlestoredb/functions/ext/collocated/server.py +++ b/singlestoredb/functions/ext/collocated/server.py @@ -24,6 +24,7 @@ from typing import Optional from typing import Tuple +from .connection import _write_all_fd from .connection import handle_connection from .registry import FunctionRegistry @@ -61,7 +62,7 @@ def _write_pipe_message(fd: int, payload: bytes) -> None: Wire format: [u32 LE length][payload]. """ header = struct.pack(' None: - """Register a function from its signature and Python source code.""" + """Register a function from its signature and function body. + + The ``code`` parameter should contain the function body, not a + full ``def`` statement or ``@udf``-decorated source. + """ try: _registry.create_function(signature, code, replace) except Exception as e: diff --git a/singlestoredb/utils/events.py b/singlestoredb/utils/events.py index 2054b09d1..1a0c66443 100644 --- a/singlestoredb/utils/events.py +++ b/singlestoredb/utils/events.py @@ -7,7 +7,7 @@ try: from IPython import get_ipython has_ipython = True -except Exception: +except (ImportError, OSError): has_ipython = False diff --git a/wit/udf.wit b/wit/udf.wit index e06e35a41..6362b5d62 100644 --- a/wit/udf.wit +++ b/wit/udf.wit @@ -14,9 +14,9 @@ interface function-handler { /// args_data_format, returns_data_format, function_type, doc describe-functions: func() -> result; - /// Register a function from its signature and Python source code. + /// Register a function from its signature and source code. /// `signature` is a JSON object matching the describe-functions element schema. - /// `code` is the Python source containing the @udf-decorated function. + /// `code` is the function body (not a full `def` statement). 
/// `replace` controls whether an existing function of the same name is overwritten. create-function: func(signature: string, code: string, replace: bool) -> result<_, string>; } From fe1d8adf3700984e846a2cdb30a08e5b11803ce3 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 2 Apr 2026 11:10:02 -0500 Subject: [PATCH 18/19] Fix recv_exact protocol desync and unchecked PyObject_Length returns Guard against protocol desynchronization when poll() times out after partial data has been consumed from the socket. In the C path (accel_recv_exact), switch to blocking mode when pos > 0 so the message is always completed. Apply the same fix to the Python fallback (_recv_exact_py) by catching TimeoutError mid-read and removing the socket timeout. Add error checking at all PyObject_Length call sites that cast the result to unsigned. PyObject_Length returns -1 on error, which when cast to unsigned long long produces ULLONG_MAX, leading to massive malloc allocations or out-of-bounds access. Each site now checks for < 0 and gotos error before casting. 
Co-Authored-By: Claude Opus 4.6 --- accel.c | 84 ++++++++++++++----- .../functions/ext/collocated/connection.py | 10 ++- 2 files changed, 74 insertions(+), 20 deletions(-) diff --git a/accel.c b/accel.c index 394578000..e66b93f93 100644 --- a/accel.c +++ b/accel.c @@ -2276,7 +2276,11 @@ static PyObject *load_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k orig_data = data; // Get number of columns - n_cols = PyObject_Length(py_colspec); + { + Py_ssize_t tmp = PyObject_Length(py_colspec); + if (tmp < 0) goto error; + n_cols = (unsigned long long)tmp; + } // Determine column types ctypes = calloc(sizeof(int), n_cols); @@ -2920,19 +2924,27 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k goto error; } - if (PyObject_Length(py_returns) != PyObject_Length(py_cols)) { - PyErr_SetString(PyExc_ValueError, "number of return values does not match number of returned columns"); - goto error; + { + Py_ssize_t tmp_returns_l = PyObject_Length(py_returns); + if (tmp_returns_l < 0) goto error; + Py_ssize_t tmp_cols_l = PyObject_Length(py_cols); + if (tmp_cols_l < 0) goto error; + if (tmp_returns_l != tmp_cols_l) { + PyErr_SetString(PyExc_ValueError, "number of return values does not match number of returned columns"); + goto error; + } + n_cols = (unsigned long long)tmp_returns_l; } - n_rows = (unsigned long long)PyObject_Length(py_row_ids); + { + Py_ssize_t tmp = PyObject_Length(py_row_ids); + if (tmp < 0) goto error; + n_rows = (unsigned long long)tmp; + } if (n_rows == 0) { py_out = PyBytes_FromStringAndSize("", 0); goto exit; } - - // Verify all data lengths agree - n_cols = (unsigned long long)PyObject_Length(py_returns); if (n_cols == 0) { py_out = PyBytes_FromStringAndSize("", 0); goto exit; @@ -2944,17 +2956,25 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k PyObject *py_data = PyTuple_GetItem(py_item, 0); if (!py_data) goto error; - if ((unsigned long long)PyObject_Length(py_data) != n_rows) 
{ - PyErr_SetString(PyExc_ValueError, "mismatched lengths of column values"); - goto error; + { + Py_ssize_t tmp = PyObject_Length(py_data); + if (tmp < 0) goto error; + if ((unsigned long long)tmp != n_rows) { + PyErr_SetString(PyExc_ValueError, "mismatched lengths of column values"); + goto error; + } } PyObject *py_mask = PyTuple_GetItem(py_item, 1); if (!py_mask) goto error; - if (py_mask != Py_None && (unsigned long long)PyObject_Length(py_mask) != n_rows) { - PyErr_SetString(PyExc_ValueError, "length of mask values does not match the length of data rows"); - goto error; + if (py_mask != Py_None) { + Py_ssize_t tmp = PyObject_Length(py_mask); + if (tmp < 0) goto error; + if ((unsigned long long)tmp != n_rows) { + PyErr_SetString(PyExc_ValueError, "length of mask values does not match the length of data rows"); + goto error; + } } } @@ -4179,7 +4199,11 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) CHECKRC(PyBytes_AsStringAndSize(py_data, &data, &length)); end = data + (unsigned long long)length; - colspec_l = PyObject_Length(py_colspec); + { + Py_ssize_t tmp = PyObject_Length(py_colspec); + if (tmp < 0) goto error; + colspec_l = (unsigned long long)tmp; + } ctypes = malloc(sizeof(int) * colspec_l); for (i = 0; i < colspec_l; i++) { @@ -4481,7 +4505,11 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) goto error; } - n_rows = (unsigned long long)PyObject_Length(py_rows); + { + Py_ssize_t tmp = PyObject_Length(py_rows); + if (tmp < 0) goto error; + n_rows = (unsigned long long)tmp; + } if (n_rows == 0) { py_out = PyBytes_FromStringAndSize("", 0); goto exit; @@ -4494,7 +4522,11 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) if (!out) goto error; // Get return types - n_cols = (unsigned long long)PyObject_Length(py_returns); + { + Py_ssize_t tmp = PyObject_Length(py_returns); + if (tmp < 0) goto error; + n_cols = (unsigned long long)tmp; + } if (n_cols == 0) 
{ PyErr_SetString(PyExc_ValueError, "no return values specified"); goto error; @@ -4809,7 +4841,11 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k if (length == 0) { py_out = PyBytes_FromStringAndSize("", 0); goto exit; } // Parse colspec types - colspec_l = (unsigned long long)PyObject_Length(py_colspec); + { + Py_ssize_t tmp = PyObject_Length(py_colspec); + if (tmp < 0) goto error; + colspec_l = (unsigned long long)tmp; + } ctypes = malloc(sizeof(int) * colspec_l); if (!ctypes) goto error; for (i = 0; i < colspec_l; i++) { @@ -4822,7 +4858,11 @@ static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *k } // Parse return types - returns_l = (unsigned long long)PyObject_Length(py_returns); + { + Py_ssize_t tmp = PyObject_Length(py_returns); + if (tmp < 0) goto error; + returns_l = (unsigned long long)tmp; + } rtypes = malloc(sizeof(int) * returns_l); if (!rtypes) goto error; for (i = 0; i < returns_l; i++) { @@ -5529,6 +5569,12 @@ static PyObject *accel_recv_exact(PyObject *self, PyObject *args) { poll_rc = poll(&pfd, 1, timeout_ms); Py_END_ALLOW_THREADS if (poll_rc == 0) { + if (pos > 0) { + /* Partial message already consumed — must finish it. + Block indefinitely to avoid protocol desync. */ + timeout_ms = -1; + continue; + } free(buf); PyErr_SetString(PyExc_TimeoutError, "recv_exact timed out"); return NULL; diff --git a/singlestoredb/functions/ext/collocated/connection.py b/singlestoredb/functions/ext/collocated/connection.py index c69bd1c13..5d0256404 100644 --- a/singlestoredb/functions/ext/collocated/connection.py +++ b/singlestoredb/functions/ext/collocated/connection.py @@ -395,7 +395,15 @@ def _recv_exact_py(sock: socket.socket, n: int) -> bytes | None: view = memoryview(buf) pos = 0 while pos < n: - nbytes = sock.recv_into(view[pos:]) + try: + nbytes = sock.recv_into(view[pos:]) + except TimeoutError: + if pos == 0: + raise + # Partial message already consumed — must finish it. 
+ # Remove timeout to avoid protocol desync. + sock.settimeout(None) + continue if nbytes == 0: return None pos += nbytes From 548edcc4a7bd16bc83ad579a02ba363e23992ade Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 2 Apr 2026 11:37:08 -0500 Subject: [PATCH 19/19] Fix _iquery DataFrame conversion for non-tuple results_type _iquery must always return List[Dict[str, Any]], but when the connection uses a non-tuple results_type (polars, pandas, numpy, arrow), the specialized cursor's fetchall() returns a DataFrame/ndarray instead of tuples. The previous code had two bugs: 1. list() on a DataFrame iterates by columns, producing Series objects instead of row dicts. 2. to_dict(orient='records') is pandas-specific and fails on polars. Dispatch on the raw fetchall() result type before converting to dicts: - pandas DataFrame: to_dict(orient='records') - polars DataFrame: to_dicts() - Arrow Table: to_pydict() with column-to-row transposition - numpy ndarray: tolist() with cursor.description column names - tuples/dicts: existing logic preserved Centralize fix_names camelCase conversion as a single post-processing step applied uniformly to all result types. 
Co-Authored-By: Claude Opus 4.6 --- singlestoredb/connection.py | 57 +++++++++++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 8 deletions(-) diff --git a/singlestoredb/connection.py b/singlestoredb/connection.py index 01245dd5f..6314debb1 100644 --- a/singlestoredb/connection.py +++ b/singlestoredb/connection.py @@ -1167,17 +1167,58 @@ def _iquery( if not re.match(r'^\s*(select|show|call|echo)\s+', oper, flags=re.I): return [] raw = cur.fetchall() - if hasattr(raw, 'to_dict') and callable(raw.to_dict): - return raw.to_dict(orient='records') - out = list(raw) - if not out: + if raw is None: return [] - if isinstance(out[0], (tuple, list)): + # pandas DataFrame + if hasattr(raw, 'to_dict') and hasattr(raw, 'columns'): + out = raw.to_dict(orient='records') + # polars DataFrame + elif hasattr(raw, 'to_dicts') and callable(raw.to_dicts): + out = raw.to_dicts() + # arrow Table + elif hasattr(raw, 'to_pydict') and callable(raw.to_pydict): + d = raw.to_pydict() + cols = list(d.keys()) + n = len(next(iter(d.values()))) if d else 0 + out = [{c: d[c][i] for c in cols} for i in range(n)] + # numpy ndarray + elif hasattr(raw, 'tolist') and hasattr(raw, 'ndim'): + rows = raw.tolist() if cur.description: names = [x[0] for x in cur.description] - if fix_names: - names = [under2camel(str(x).replace(' ', '')) for x in names] - out = [{k: v for k, v in zip(names, row)} for row in out] + out = [ + {k: v for k, v in zip(names, row)} + for row in rows + ] + else: + return [] + # list of tuples/namedtuples/dicts + else: + out = list(raw) + if not out: + return [] + if isinstance(out[0], dict): + pass # already dicts + elif isinstance(out[0], (tuple, list)): + if cur.description: + names = [x[0] for x in cur.description] + out = [ + {k: v for k, v in zip(names, row)} + for row in out + ] + else: + return [] + if not out: + return [] + # Apply camelCase name conversion if requested + if fix_names: + out = [ + { + under2camel(str(k).replace(' ', '')): v + for k, v in 
row.items() + } + for row in out + ] return out @abc.abstractmethod