diff --git a/accel.c b/accel.c index ee9a72f3..e66b93f9 100644 --- a/accel.c +++ b/accel.c @@ -1,7 +1,16 @@ #include +#ifndef __wasi__ +#include +#endif #include #include +#include +#ifndef __wasi__ +#include +#include +#endif +#include #include #ifndef Py_LIMITED_API @@ -2267,7 +2276,11 @@ static PyObject *load_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k orig_data = data; // Get number of columns - n_cols = PyObject_Length(py_colspec); + { + Py_ssize_t tmp = PyObject_Length(py_colspec); + if (tmp < 0) goto error; + n_cols = (unsigned long long)tmp; + } // Determine column types ctypes = calloc(sizeof(int), n_cols); @@ -2911,19 +2924,27 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k goto error; } - if (PyObject_Length(py_returns) != PyObject_Length(py_cols)) { - PyErr_SetString(PyExc_ValueError, "number of return values does not match number of returned columns"); - goto error; + { + Py_ssize_t tmp_returns_l = PyObject_Length(py_returns); + if (tmp_returns_l < 0) goto error; + Py_ssize_t tmp_cols_l = PyObject_Length(py_cols); + if (tmp_cols_l < 0) goto error; + if (tmp_returns_l != tmp_cols_l) { + PyErr_SetString(PyExc_ValueError, "number of return values does not match number of returned columns"); + goto error; + } + n_cols = (unsigned long long)tmp_returns_l; } - n_rows = (unsigned long long)PyObject_Length(py_row_ids); + { + Py_ssize_t tmp = PyObject_Length(py_row_ids); + if (tmp < 0) goto error; + n_rows = (unsigned long long)tmp; + } if (n_rows == 0) { py_out = PyBytes_FromStringAndSize("", 0); goto exit; } - - // Verify all data lengths agree - n_cols = (unsigned long long)PyObject_Length(py_returns); if (n_cols == 0) { py_out = PyBytes_FromStringAndSize("", 0); goto exit; @@ -2935,17 +2956,25 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k PyObject *py_data = PyTuple_GetItem(py_item, 0); if (!py_data) goto error; - if ((unsigned long long)PyObject_Length(py_data) != n_rows) { - PyErr_SetString(PyExc_ValueError, "mismatched lengths of column values"); - goto error; + { + Py_ssize_t tmp = PyObject_Length(py_data); + if (tmp < 0) goto error; + if ((unsigned long long)tmp != n_rows) { + PyErr_SetString(PyExc_ValueError, "mismatched lengths of column values"); + goto error; + } } PyObject *py_mask = PyTuple_GetItem(py_item, 1); if (!py_mask) goto error; - if (py_mask != Py_None && (unsigned long long)PyObject_Length(py_mask) != n_rows) { - PyErr_SetString(PyExc_ValueError, "length of mask values does not match the length of data rows"); - goto error; + if (py_mask != Py_None) { + Py_ssize_t tmp = PyObject_Length(py_mask); + if (tmp < 0) goto error; + if ((unsigned long long)tmp != n_rows) { + PyErr_SetString(PyExc_ValueError, "length of mask values does not match the length of data rows"); + goto error; + } } } @@ -4170,7 +4199,11 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) CHECKRC(PyBytes_AsStringAndSize(py_data, &data, &length)); end = data + (unsigned long long)length; - colspec_l = PyObject_Length(py_colspec); + { + Py_ssize_t tmp = PyObject_Length(py_colspec); + if (tmp < 0) goto error; + colspec_l = (unsigned long long)tmp; + } ctypes = malloc(sizeof(int) * colspec_l); for (i = 0; i < colspec_l; i++) { @@ -4472,7 +4505,11 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) goto error; } - n_rows = (unsigned long long)PyObject_Length(py_rows); + { + Py_ssize_t tmp = PyObject_Length(py_rows); + if (tmp < 0) goto error; + n_rows = (unsigned long long)tmp; + } if (n_rows == 0) { py_out = PyBytes_FromStringAndSize("", 0); goto exit; @@ -4485,7 +4522,11 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) if (!out) goto error; // Get return types - n_cols = (unsigned long long)PyObject_Length(py_returns); + { + Py_ssize_t tmp = PyObject_Length(py_returns); + if (tmp < 0) goto error; + n_cols = (unsigned long long)tmp; + } if (n_cols == 0) { PyErr_SetString(PyExc_ValueError, "no return values specified"); goto error; @@ -4774,12 +4815,828 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs) } +static PyObject *call_function_accel(PyObject *self, PyObject *args, PyObject *kwargs) { + PyObject *py_colspec = NULL, *py_returns = NULL, *py_data = NULL, *py_func = NULL; + PyObject *py_out = NULL, *py_row = NULL, *py_result = NULL, *py_result_item = NULL; + PyObject *py_str = NULL, *py_blob = NULL, *py_bytes = NULL; + Py_ssize_t length = 0; + uint64_t row_id = 0; + uint8_t is_null = 0; + int8_t i8 = 0; int16_t i16 = 0; int32_t i32 = 0; int64_t i64 = 0; + uint8_t u8 = 0; uint16_t u16 = 0; uint32_t u32 = 0; uint64_t u64 = 0; + float flt = 0; double dbl = 0; + int *ctypes = NULL, *rtypes = NULL; + char *data = NULL, *end = NULL, *out = NULL; + unsigned long long out_l = 0, out_idx = 0, colspec_l = 0, returns_l = 0, i = 0; + char *keywords[] = {"colspec", "returns", "data", "func", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OOOO", keywords, + &py_colspec, &py_returns, &py_data, &py_func)) goto error; + if (!PyCallable_Check(py_func)) { + PyErr_SetString(PyExc_TypeError, "func must be callable"); goto error; + } + + CHECKRC(PyBytes_AsStringAndSize(py_data, &data, &length)); + end = data + (unsigned long long)length; + if (length == 0) { py_out = PyBytes_FromStringAndSize("", 0); goto exit; } + + // Parse colspec types + { + Py_ssize_t tmp = PyObject_Length(py_colspec); + if (tmp < 0) goto error; + colspec_l = (unsigned long long)tmp; + } + ctypes = malloc(sizeof(int) * colspec_l); + if (!ctypes) goto error; + for (i = 0; i < colspec_l; i++) { + PyObject *py_cspec = PySequence_GetItem(py_colspec, i); + if (!py_cspec) goto error; + PyObject *py_ctype = PySequence_GetItem(py_cspec, 1); + if (!py_ctype) { Py_DECREF(py_cspec); goto error; } + ctypes[i] = (int)PyLong_AsLong(py_ctype); + Py_DECREF(py_ctype); Py_DECREF(py_cspec); + } + + // Parse return types + { + Py_ssize_t tmp = PyObject_Length(py_returns); + if (tmp < 0) goto error; + returns_l = (unsigned long long)tmp; + } + rtypes = malloc(sizeof(int) * returns_l); + if (!rtypes) goto error; + for (i = 0; i < returns_l; i++) { + PyObject *py_item = PySequence_GetItem(py_returns, i); + if (!py_item) goto error; + rtypes[i] = (int)PyLong_AsLong(py_item); + Py_DECREF(py_item); + } + + out_l = 256; + out = malloc(out_l); + if (!out) goto error; + +#define CHECKMEM_CFA(x) \ + if ((out_idx + (x)) > out_l) { \ + out_l = out_l * 2 + (x); \ + char *new_out = realloc(out, out_l); \ + if (!new_out) { \ + PyErr_SetString(PyExc_MemoryError, "failed to reallocate output buffer"); \ + goto error; \ + } \ + out = new_out; \ + } + + // Bounds-check macro for input buffer reads +#define CHECK_REMAINING(n) do { \ + if ((size_t)(end - data) < (size_t)(n)) { \ + PyErr_SetString(PyExc_ValueError, "truncated rowdat_1 input"); \ + goto error; \ + } \ +} while(0) + + // Main loop: parse input rows, call function, serialize output + while (end > data) { + py_row = PyTuple_New(colspec_l); + if (!py_row) goto error; + + // Read row ID + CHECK_REMAINING(8); + memcpy(&row_id, data, 8); data += 8; + + // Parse input columns + for (i = 0; i < colspec_l; i++) { + CHECK_REMAINING(1); + is_null = data[0] == '\x01'; data += 1; + + switch (ctypes[i]) { + case MYSQL_TYPE_NULL: + Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + break; + + case MYSQL_TYPE_TINY: + CHECK_REMAINING(1); + memcpy(&i8, data, 1); data += 1; + if (is_null) { + Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromLong((long)i8))); + } + break; + + case -MYSQL_TYPE_TINY: + CHECK_REMAINING(1); + memcpy(&u8, data, 1); data += 1; + if (is_null) { + Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLong((unsigned long)u8))); + } + break; + + case MYSQL_TYPE_SHORT: + CHECK_REMAINING(2); + memcpy(&i16, data, 2); data += 2; + if (is_null) { + Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromLong((long)i16))); + } + break; + + case -MYSQL_TYPE_SHORT: + CHECK_REMAINING(2); + memcpy(&u16, data, 2); data += 2; + if (is_null) { + Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLong((unsigned long)u16))); + } + break; + + case MYSQL_TYPE_LONG: + case MYSQL_TYPE_INT24: + CHECK_REMAINING(4); + memcpy(&i32, data, 4); data += 4; + if (is_null) { + Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromLong((long)i32))); + } + break; + + case -MYSQL_TYPE_LONG: + case -MYSQL_TYPE_INT24: + CHECK_REMAINING(4); + memcpy(&u32, data, 4); data += 4; + if (is_null) { + Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLong((unsigned long)u32))); + } + break; + + case MYSQL_TYPE_LONGLONG: + CHECK_REMAINING(8); + memcpy(&i64, data, 8); data += 8; + if (is_null) { + Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromLongLong((long long)i64))); + } + break; + + case -MYSQL_TYPE_LONGLONG: + CHECK_REMAINING(8); + memcpy(&u64, data, 8); data += 8; + if (is_null) { + Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLongLong((unsigned long long)u64))); + } + break; + + case MYSQL_TYPE_FLOAT: + CHECK_REMAINING(4); + memcpy(&flt, data, 4); data += 4; + if (is_null) { + Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyFloat_FromDouble((double)flt))); + } + break; + + case MYSQL_TYPE_DOUBLE: + CHECK_REMAINING(8); + memcpy(&dbl, data, 8); data += 8; + if (is_null) { + Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyFloat_FromDouble((double)dbl))); + } + break; + + case MYSQL_TYPE_DECIMAL: + case MYSQL_TYPE_NEWDECIMAL: + PyErr_SetString(PyExc_NotImplementedError, + "DECIMAL type not yet supported in call_function_accel"); + goto error; + + case MYSQL_TYPE_DATE: + case MYSQL_TYPE_NEWDATE: + PyErr_SetString(PyExc_NotImplementedError, + "DATE type not yet supported in call_function_accel"); + goto error; + + case MYSQL_TYPE_TIME: + PyErr_SetString(PyExc_NotImplementedError, + "TIME type not yet supported in call_function_accel"); + goto error; + + case MYSQL_TYPE_DATETIME: + PyErr_SetString(PyExc_NotImplementedError, + "DATETIME type not yet supported in call_function_accel"); + goto error; + + case MYSQL_TYPE_TIMESTAMP: + PyErr_SetString(PyExc_NotImplementedError, + "TIMESTAMP type not yet supported in call_function_accel"); + goto error; + + case MYSQL_TYPE_YEAR: + CHECK_REMAINING(2); + memcpy(&u16, data, 2); data += 2; + if (is_null) { + Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + } else { + CHECKRC(PyTuple_SetItem(py_row, i, PyLong_FromUnsignedLong((unsigned long)u16))); + } + break; + + case MYSQL_TYPE_VARCHAR: + case MYSQL_TYPE_JSON: + case MYSQL_TYPE_SET: + case MYSQL_TYPE_ENUM: + case MYSQL_TYPE_VAR_STRING: + case MYSQL_TYPE_STRING: + case MYSQL_TYPE_GEOMETRY: + case MYSQL_TYPE_TINY_BLOB: + case MYSQL_TYPE_MEDIUM_BLOB: + case MYSQL_TYPE_LONG_BLOB: + case MYSQL_TYPE_BLOB: + CHECK_REMAINING(8); + memcpy(&i64, data, 8); data += 8; + if (is_null) { + Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + } else { + CHECK_REMAINING((size_t)i64); + py_str = PyUnicode_FromStringAndSize(data, (Py_ssize_t)i64); + data += i64; + if (!py_str) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_str)); + py_str = NULL; + } + break; + + case -MYSQL_TYPE_VARCHAR: + case -MYSQL_TYPE_JSON: + case -MYSQL_TYPE_SET: + case -MYSQL_TYPE_ENUM: + case -MYSQL_TYPE_VAR_STRING: + case -MYSQL_TYPE_STRING: + case -MYSQL_TYPE_GEOMETRY: + case -MYSQL_TYPE_TINY_BLOB: + case -MYSQL_TYPE_MEDIUM_BLOB: + case -MYSQL_TYPE_LONG_BLOB: + case -MYSQL_TYPE_BLOB: + CHECK_REMAINING(8); + memcpy(&i64, data, 8); data += 8; + if (is_null) { + Py_INCREF(Py_None); + CHECKRC(PyTuple_SetItem(py_row, i, Py_None)); + } else { + CHECK_REMAINING((size_t)i64); + py_blob = PyBytes_FromStringAndSize(data, (Py_ssize_t)i64); + data += i64; + if (!py_blob) goto error; + CHECKRC(PyTuple_SetItem(py_row, i, py_blob)); + py_blob = NULL; + } + break; + + default: + PyErr_Format(PyExc_TypeError, + "unsupported input column type: %d", ctypes[i]); + goto error; + } + } + +#undef CHECK_REMAINING + + // Call the user function + py_result = PyObject_Call(py_func, py_row, NULL); + Py_DECREF(py_row); + py_row = NULL; + if (!py_result) goto error; + + // Normalize result: wrap scalar in a tuple + if (!PyList_Check(py_result) && !PyTuple_Check(py_result)) { + PyObject *py_wrapped = PyTuple_Pack(1, py_result); + Py_DECREF(py_result); + py_result = py_wrapped; + if (!py_result) goto error; + } + + // Write row ID to output + CHECKMEM_CFA(8); + memcpy(out+out_idx, &row_id, 8); + out_idx += 8; + + // Serialize output columns + for (i = 0; i < returns_l; i++) { + py_result_item = PySequence_GetItem(py_result, i); + if (!py_result_item) goto error; + + is_null = (uint8_t)(py_result_item == Py_None); + + CHECKMEM_CFA(1); + memcpy(out+out_idx, &is_null, 1); + out_idx += 1; + + switch (rtypes[i]) { + case MYSQL_TYPE_BIT: + Py_DECREF(py_result_item); + py_result_item = NULL; + PyErr_SetString(PyExc_NotImplementedError, + "BIT type not yet supported in call_function_accel"); + goto error; + + case MYSQL_TYPE_TINY: + CHECKMEM_CFA(1); + if (is_null) { + i8 = 0; + } else { + i8 = (int8_t)PyLong_AsLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } + memcpy(out+out_idx, &i8, 1); + out_idx += 1; + break; + + case -MYSQL_TYPE_TINY: + CHECKMEM_CFA(1); + if (is_null) { + u8 = 0; + } else { + u8 = (uint8_t)PyLong_AsUnsignedLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } + memcpy(out+out_idx, &u8, 1); + out_idx += 1; + break; + + case MYSQL_TYPE_SHORT: + CHECKMEM_CFA(2); + if (is_null) { + i16 = 0; + } else { + i16 = (int16_t)PyLong_AsLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } + memcpy(out+out_idx, &i16, 2); + out_idx += 2; + break; + + case -MYSQL_TYPE_SHORT: + CHECKMEM_CFA(2); + if (is_null) { + u16 = 0; + } else { + u16 = (uint16_t)PyLong_AsUnsignedLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } + memcpy(out+out_idx, &u16, 2); + out_idx += 2; + break; + + case MYSQL_TYPE_LONG: + case MYSQL_TYPE_INT24: + CHECKMEM_CFA(4); + if (is_null) { + i32 = 0; + } else { + i32 = (int32_t)PyLong_AsLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } + memcpy(out+out_idx, &i32, 4); + out_idx += 4; + break; + + case -MYSQL_TYPE_LONG: + case -MYSQL_TYPE_INT24: + CHECKMEM_CFA(4); + if (is_null) { + u32 = 0; + } else { + u32 = (uint32_t)PyLong_AsUnsignedLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } + memcpy(out+out_idx, &u32, 4); + out_idx += 4; + break; + + case MYSQL_TYPE_LONGLONG: + CHECKMEM_CFA(8); + if (is_null) { + i64 = 0; + } else { + i64 = (int64_t)PyLong_AsLongLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } + memcpy(out+out_idx, &i64, 8); + out_idx += 8; + break; + + case -MYSQL_TYPE_LONGLONG: + CHECKMEM_CFA(8); + if (is_null) { + u64 = 0; + } else { + u64 = (uint64_t)PyLong_AsUnsignedLongLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } + memcpy(out+out_idx, &u64, 8); + out_idx += 8; + break; + + case MYSQL_TYPE_FLOAT: + CHECKMEM_CFA(4); + if (is_null) { + flt = 0; + } else { + flt = (float)PyFloat_AsDouble(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } + memcpy(out+out_idx, &flt, 4); + out_idx += 4; + break; + + case MYSQL_TYPE_DOUBLE: + CHECKMEM_CFA(8); + if (is_null) { + dbl = 0; + } else { + dbl = (double)PyFloat_AsDouble(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } + memcpy(out+out_idx, &dbl, 8); + out_idx += 8; + break; + + case MYSQL_TYPE_DECIMAL: + case MYSQL_TYPE_NEWDECIMAL: + Py_DECREF(py_result_item); + py_result_item = NULL; + PyErr_SetString(PyExc_NotImplementedError, + "DECIMAL type not yet supported in call_function_accel"); + goto error; + + case MYSQL_TYPE_DATE: + case MYSQL_TYPE_NEWDATE: + Py_DECREF(py_result_item); + py_result_item = NULL; + PyErr_SetString(PyExc_NotImplementedError, + "DATE type not yet supported in call_function_accel"); + goto error; + + case MYSQL_TYPE_TIME: + Py_DECREF(py_result_item); + py_result_item = NULL; + PyErr_SetString(PyExc_NotImplementedError, + "TIME type not yet supported in call_function_accel"); + goto error; + + case MYSQL_TYPE_DATETIME: + Py_DECREF(py_result_item); + py_result_item = NULL; + PyErr_SetString(PyExc_NotImplementedError, + "DATETIME type not yet supported in call_function_accel"); + goto error; + + case MYSQL_TYPE_TIMESTAMP: + Py_DECREF(py_result_item); + py_result_item = NULL; + PyErr_SetString(PyExc_NotImplementedError, + "TIMESTAMP type not yet supported in call_function_accel"); + goto error; + + case MYSQL_TYPE_YEAR: + CHECKMEM_CFA(2); + if (is_null) { + i16 = 0; + } else { + i16 = (int16_t)PyLong_AsLong(py_result_item); + if (PyErr_Occurred()) { Py_DECREF(py_result_item); py_result_item = NULL; goto error; } + } + memcpy(out+out_idx, &i16, 2); + out_idx += 2; + break; + + case MYSQL_TYPE_VARCHAR: + case MYSQL_TYPE_JSON: + case MYSQL_TYPE_SET: + case MYSQL_TYPE_ENUM: + case MYSQL_TYPE_VAR_STRING: + case MYSQL_TYPE_STRING: + case MYSQL_TYPE_GEOMETRY: + case MYSQL_TYPE_TINY_BLOB: + case MYSQL_TYPE_MEDIUM_BLOB: + case MYSQL_TYPE_LONG_BLOB: + case MYSQL_TYPE_BLOB: + if (is_null) { + CHECKMEM_CFA(8); + i64 = 0; + memcpy(out+out_idx, &i64, 8); + out_idx += 8; + } else { + py_bytes = PyUnicode_AsEncodedString(py_result_item, "utf-8", "strict"); + if (!py_bytes) { + Py_DECREF(py_result_item); + py_result_item = NULL; + goto error; + } + + char *str = NULL; + Py_ssize_t str_l = 0; + if (PyBytes_AsStringAndSize(py_bytes, &str, &str_l) < 0) { + Py_DECREF(py_bytes); + py_bytes = NULL; + Py_DECREF(py_result_item); + py_result_item = NULL; + goto error; + } + + CHECKMEM_CFA(8+str_l); + i64 = str_l; + memcpy(out+out_idx, &i64, 8); + out_idx += 8; + memcpy(out+out_idx, str, str_l); + out_idx += str_l; + Py_DECREF(py_bytes); + py_bytes = NULL; + } + break; + + case -MYSQL_TYPE_VARCHAR: + case -MYSQL_TYPE_JSON: + case -MYSQL_TYPE_SET: + case -MYSQL_TYPE_ENUM: + case -MYSQL_TYPE_VAR_STRING: + case -MYSQL_TYPE_STRING: + case -MYSQL_TYPE_GEOMETRY: + case -MYSQL_TYPE_TINY_BLOB: + case -MYSQL_TYPE_MEDIUM_BLOB: + case -MYSQL_TYPE_LONG_BLOB: + case -MYSQL_TYPE_BLOB: + if (is_null) { + CHECKMEM_CFA(8); + i64 = 0; + memcpy(out+out_idx, &i64, 8); + out_idx += 8; + } else { + char *str = NULL; + Py_ssize_t str_l = 0; + if (PyBytes_AsStringAndSize(py_result_item, &str, &str_l) < 0) { + Py_DECREF(py_result_item); + py_result_item = NULL; + goto error; + } + + CHECKMEM_CFA(8+str_l); + i64 = str_l; + memcpy(out+out_idx, &i64, 8); + out_idx += 8; + memcpy(out+out_idx, str, str_l); + out_idx += str_l; + } + break; + + default: + Py_DECREF(py_result_item); + py_result_item = NULL; + PyErr_Format(PyExc_TypeError, + "unsupported output column type: %d", rtypes[i]); + goto error; + } + + Py_DECREF(py_result_item); + py_result_item = NULL; + } + + Py_DECREF(py_result); + py_result = NULL; + } + +#undef CHECKMEM_CFA + + py_out = PyBytes_FromStringAndSize(out, out_idx); + +exit: + if (out) free(out); + if (ctypes) free(ctypes); + if (rtypes) free(rtypes); + + Py_XDECREF(py_row); + Py_XDECREF(py_result); + Py_XDECREF(py_result_item); + Py_XDECREF(py_str); + Py_XDECREF(py_blob); + Py_XDECREF(py_bytes); + + return py_out; + +error: + Py_XDECREF(py_out); + py_out = NULL; + + goto exit; +} + +#ifndef __wasi__ +/* + * mmap_read(fd, length) -> bytes + * + * Maps the given fd with MAP_SHARED|PROT_READ for `length` bytes, + * copies into a Python bytes object, and unmaps in a single C call. + * Eliminates Python mmap object creation/destruction overhead. + */ +static PyObject *accel_mmap_read(PyObject *self, PyObject *args) { + int fd; + Py_ssize_t length; + + if (!PyArg_ParseTuple(args, "in", &fd, &length)) + return NULL; + + if (length <= 0) { + return PyBytes_FromStringAndSize(NULL, 0); + } + + void *addr = mmap(NULL, (size_t)length, PROT_READ, MAP_SHARED, fd, 0); + if (addr == MAP_FAILED) { + PyErr_SetFromErrno(PyExc_OSError); + return NULL; + } + + PyObject *result = PyBytes_FromStringAndSize((const char *)addr, length); + munmap(addr, (size_t)length); + return result; +} + +/* + * mmap_write(fd, data, min_size) -> None + * + * Writes `data` to the file descriptor, combining ftruncate + lseek + write + * into a single C call. If min_size > 0, ftruncate is called with + * max(min_size, len(data)); if min_size == 0, ftruncate is skipped + * (caller manages file size). + */ +static PyObject *accel_mmap_write(PyObject *self, PyObject *args) { + int fd; + const char *data; + Py_ssize_t data_len; + Py_ssize_t min_size; + + if (!PyArg_ParseTuple(args, "iy#n", &fd, &data, &data_len, &min_size)) + return NULL; + + if (min_size > 0) { + Py_ssize_t trunc_size = data_len > min_size ? data_len : min_size; + if (ftruncate(fd, (off_t)trunc_size) < 0) { + PyErr_SetFromErrno(PyExc_OSError); + return NULL; + } + } + + if (lseek(fd, 0, SEEK_SET) < 0) { + PyErr_SetFromErrno(PyExc_OSError); + return NULL; + } + + const char *p = data; + Py_ssize_t remaining = data_len; + while (remaining > 0) { + ssize_t written = write(fd, p, (size_t)remaining); + if (written < 0) { + PyErr_SetFromErrno(PyExc_OSError); + return NULL; + } + p += written; + remaining -= written; + } + + Py_RETURN_NONE; +} + +/* + * recv_exact(fd, n, timeout_ms=-1) -> bytes or None + * + * Receives exactly `n` bytes from a socket fd using blocking recv. + * Returns None on EOF (peer closed). Operates on raw fd to avoid + * Python socket object overhead. Releases the GIL during recv. + * + * When timeout_ms >= 0, uses poll() before each recv() to wait for + * data with a timeout. Raises TimeoutError on timeout. This allows + * the fd to remain in blocking mode while still supporting timeouts, + * avoiding the interaction between Python's settimeout() (which sets + * O_NONBLOCK) and direct fd-level recv(). + */ +static PyObject *accel_recv_exact(PyObject *self, PyObject *args) { + int fd, timeout_ms = -1; + Py_ssize_t n; + + if (!PyArg_ParseTuple(args, "in|i", &fd, &n, &timeout_ms)) + return NULL; + + if (n <= 0) { + return PyBytes_FromStringAndSize(NULL, 0); + } + + char *buf = (char *)malloc((size_t)n); + if (!buf) { + PyErr_NoMemory(); + return NULL; + } + + Py_ssize_t pos = 0; + while (pos < n) { + if (timeout_ms >= 0) { + struct pollfd pfd = {fd, POLLIN, 0}; + int poll_rc; + Py_BEGIN_ALLOW_THREADS + poll_rc = poll(&pfd, 1, timeout_ms); + Py_END_ALLOW_THREADS + if (poll_rc == 0) { + if (pos > 0) { + /* Partial message already consumed — must finish it. + Block indefinitely to avoid protocol desync. */ + timeout_ms = -1; + continue; + } + free(buf); + PyErr_SetString(PyExc_TimeoutError, "recv_exact timed out"); + return NULL; + } + if (poll_rc < 0) { + free(buf); + PyErr_SetFromErrno(PyExc_OSError); + return NULL; + } + } + + ssize_t received; + Py_BEGIN_ALLOW_THREADS + received = recv(fd, buf + pos, (size_t)(n - pos), 0); + Py_END_ALLOW_THREADS + + if (received < 0) { + free(buf); + PyErr_SetFromErrno(PyExc_OSError); + return NULL; + } + if (received == 0) { + /* EOF */ + free(buf); + Py_RETURN_NONE; + } + pos += received; + } + + PyObject *result = PyBytes_FromStringAndSize(buf, n); + free(buf); + return result; +} +#else /* __wasi__ stubs — importable but raise NotImplementedError if called */ + +static PyObject *accel_mmap_read(PyObject *self, PyObject *args) { + PyErr_SetString(PyExc_NotImplementedError, "mmap_read is not available in WASM"); + return NULL; +} + +static PyObject *accel_mmap_write(PyObject *self, PyObject *args) { + PyErr_SetString(PyExc_NotImplementedError, "mmap_write is not available in WASM"); + return NULL; +} + +static PyObject *accel_recv_exact(PyObject *self, PyObject *args) { + PyErr_SetString(PyExc_NotImplementedError, "recv_exact is not available in WASM"); + return NULL; +} + +#endif /* !__wasi__ */ + static PyMethodDef PyMySQLAccelMethods[] = { {"read_rowdata_packet", (PyCFunction)read_rowdata_packet, METH_VARARGS | METH_KEYWORDS, "PyMySQL row data packet reader"}, {"dump_rowdat_1", (PyCFunction)dump_rowdat_1, METH_VARARGS | METH_KEYWORDS, "ROWDAT_1 formatter for external functions"}, {"load_rowdat_1", (PyCFunction)load_rowdat_1, METH_VARARGS | METH_KEYWORDS, "ROWDAT_1 parser for external functions"}, {"dump_rowdat_1_numpy", (PyCFunction)dump_rowdat_1_numpy, METH_VARARGS | METH_KEYWORDS, "ROWDAT_1 formatter for external functions which takes numpy.arrays"}, {"load_rowdat_1_numpy", (PyCFunction)load_rowdat_1_numpy, METH_VARARGS | METH_KEYWORDS, "ROWDAT_1 parser for external functions which creates numpy.arrays"}, + {"call_function_accel", (PyCFunction)call_function_accel, METH_VARARGS | METH_KEYWORDS, "Combined load/call/dump for UDF function calls"}, + {"mmap_read", (PyCFunction)accel_mmap_read, METH_VARARGS, "mmap read: maps fd, copies data, unmaps"}, + {"mmap_write", (PyCFunction)accel_mmap_write, METH_VARARGS, "mmap write: ftruncate+lseek+write in one call"}, + {"recv_exact", (PyCFunction)accel_recv_exact, METH_VARARGS, "Receive exactly N bytes from a socket fd"}, {NULL, NULL, 0, NULL} }; diff --git a/pyproject.toml b/pyproject.toml index 8a8e91fd..c910bce9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,9 @@ dev = [ "singlestoredb[test,docs,build]", ] +[project.scripts] +python-udf-server = "singlestoredb.functions.ext.collocated.__main__:main" + [project.entry-points.pytest11] singlestoredb = "singlestoredb.pytest" diff --git a/resources/build_wasm.sh b/resources/build_wasm.sh new file mode 100755 index 00000000..820acca0 --- /dev/null +++ b/resources/build_wasm.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +set -eou pipefail + +# CPYTHON_ROOT must contain a build of cpython for wasm32-wasip2 + +TARGET="wasm32-wasip2" +CROSS_BUILD="${CPYTHON_ROOT}/cross-build/${TARGET}" +WASI_SDK_PATH=${WASI_SDK_PATH:-/opt/wasi-sdk} +PYTHON_VERSION=$(grep '^VERSION=' "${CROSS_BUILD}/Makefile" | sed 's/VERSION=[[:space:]]*//') + +if [ ! -e wasm_venv ]; then + uv venv --python ${PYTHON_VERSION} wasm_venv +fi + +. wasm_venv/bin/activate + +HOST_PYTHON=$(which python3) + +uv pip install build wheel cython setuptools + +ARCH_TRIPLET=_wasi_wasm32-wasi + +export CC="${WASI_SDK_PATH}/bin/clang" +export CXX="${WASI_SDK_PATH}/bin/clang++" + +export PYTHONPATH="${CROSS_BUILD}/build/lib.wasi-wasm32-${PYTHON_VERSION}" + +export CFLAGS="--target=${TARGET} -fPIC -I${CROSS_BUILD}/install/include/python${PYTHON_VERSION} -D__EMSCRIPTEN__=1" +export CXXFLAGS="--target=${TARGET} -fPIC -I${CROSS_BUILD}/install/include/python${PYTHON_VERSION}" +export LDSHARED=${CC} +export AR="${WASI_SDK_PATH}/bin/ar" +export RANLIB=true +export LDFLAGS="--target=${TARGET} -shared -Wl,--allow-undefined" +export _PYTHON_SYSCONFIGDATA_NAME=_sysconfigdata__wasi_wasm32-wasi +export _PYTHON_HOST_PLATFORM=wasm32-wasi + +python3 -m build -n -w +wheel unpack --dest build dist/*.whl + +rm -rf ./wasm_venv diff --git a/singlestoredb/auth.py b/singlestoredb/auth.py index 2e10da28..fe94e341 100644 --- a/singlestoredb/auth.py +++ b/singlestoredb/auth.py @@ -5,8 +5,6 @@ from typing import Optional from typing import Union -import jwt - # Credential types PASSWORD = 'password' @@ -42,6 +40,7 @@ def __init__( @classmethod def from_token(cls, token: bytes, verify_signature: bool = False) -> 'JSONWebToken': """Validate the contents of the JWT.""" + import jwt info = jwt.decode(token, options={'verify_signature': verify_signature}) if not info.get('sub', None) and not info.get('username', None): diff --git a/singlestoredb/connection.py b/singlestoredb/connection.py index 942b2feb..6314debb 100644 --- a/singlestoredb/connection.py +++ b/singlestoredb/connection.py @@ -25,12 +25,6 @@ from urllib.parse import urlparse import sqlparams -try: - from pandas import DataFrame -except ImportError: - class DataFrame(object): # type: ignore - def itertuples(self, *args: Any, **kwargs: Any) -> None: - pass from . import auth from . import exceptions @@ -1172,17 +1166,59 @@ def _iquery( cur.execute(oper, params) if not re.match(r'^\s*(select|show|call|echo)\s+', oper, flags=re.I): return [] - out = list(cur.fetchall()) - if not out: + raw = cur.fetchall() + if raw is None: return [] - if isinstance(out, DataFrame): - out = out.to_dict(orient='records') - elif isinstance(out[0], (tuple, list)): + # pandas DataFrame + if hasattr(raw, 'to_dict') and hasattr(raw, 'columns'): + out = raw.to_dict(orient='records') + # polars DataFrame + elif hasattr(raw, 'to_dicts') and callable(raw.to_dicts): + out = raw.to_dicts() + # arrow Table + elif hasattr(raw, 'to_pydict') and callable(raw.to_pydict): + d = raw.to_pydict() + cols = list(d.keys()) + n = len(next(iter(d.values()))) if d else 0 + out = [{c: d[c][i] for c in cols} for i in range(n)] + # numpy ndarray + elif hasattr(raw, 'tolist') and hasattr(raw, 'ndim'): + rows = raw.tolist() if cur.description: names = [x[0] for x in cur.description] - if fix_names: - names = [under2camel(str(x).replace(' ', '')) for x in names] - out = [{k: v for k, v in zip(names, row)} for row in out] + out = [ + {k: v for k, v in zip(names, row)} + for row in rows + ] + else: + return [] + # list of tuples/namedtuples/dicts + else: + out = list(raw) + if not out: + return [] + if isinstance(out[0], dict): + pass # already dicts + elif isinstance(out[0], (tuple, list)): + if cur.description: + names = [x[0] for x in cur.description] + out = [ + {k: v for k, v in zip(names, row)} + for row in out + ] + else: + return [] + if not out: + return [] + # Apply camelCase name conversion if requested + if fix_names: + out = [ + { + under2camel(str(k).replace(' ', '')): v + for k, v in row.items() + } + for row in out + ] return out @abc.abstractmethod diff --git a/singlestoredb/converters.py b/singlestoredb/converters.py index ec9b7358..818c18ed 100644 --- a/singlestoredb/converters.py +++ b/singlestoredb/converters.py @@ -26,11 +26,7 @@ except (AttributeError, ImportError): has_pygeos = False -try: - import numpy - has_numpy = True -except ImportError: - has_numpy = False +from .utils._lazy_import import get_numpy try: import bson @@ -563,8 +559,9 @@ def float32_vector_json_or_none(x: Optional[str]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.array(json_loads(x), dtype=numpy.float32) + np = get_numpy() + if np is not None: + return np.array(json_loads(x), dtype=np.float32) return map(float, json_loads(x)) @@ -591,8 +588,9 @@ def float32_vector_or_none(x: Optional[bytes]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.frombuffer(x, dtype=numpy.float32) + np = get_numpy() + if np is not None: + return np.frombuffer(x, dtype=np.float32) return struct.unpack(f'<{len(x)//4}f', x) @@ -619,8 +617,9 @@ def float16_vector_json_or_none(x: Optional[str]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.array(json_loads(x), dtype=numpy.float16) + np = get_numpy() + if np is not None: + return np.array(json_loads(x), dtype=np.float16) return map(float, json_loads(x)) @@ -647,8 +646,9 @@ def float16_vector_or_none(x: Optional[bytes]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.frombuffer(x, dtype=numpy.float16) + np = get_numpy() + if np is not None: + return np.frombuffer(x, dtype=np.float16) return struct.unpack(f'<{len(x)//2}e', x) @@ -675,8 +675,9 @@ def float64_vector_json_or_none(x: Optional[str]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.array(json_loads(x), dtype=numpy.float64) + np = get_numpy() + if np is not None: + return np.array(json_loads(x), dtype=np.float64) return map(float, json_loads(x)) @@ -703,8 +704,9 @@ def float64_vector_or_none(x: Optional[bytes]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.frombuffer(x, dtype=numpy.float64) + np = get_numpy() + if np is not None: + return np.frombuffer(x, dtype=np.float64) return struct.unpack(f'<{len(x)//8}d', x) @@ -731,8 +733,9 @@ def int8_vector_json_or_none(x: Optional[str]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.array(json_loads(x), dtype=numpy.int8) + np = get_numpy() + if np is not None: + return np.array(json_loads(x), dtype=np.int8) return map(int, json_loads(x)) @@ -759,8 +762,9 @@ def int8_vector_or_none(x: Optional[bytes]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.frombuffer(x, dtype=numpy.int8) + np = get_numpy() + if np is not None: + return np.frombuffer(x, dtype=np.int8) return struct.unpack(f'<{len(x)}b', x) @@ -787,8 +791,9 @@ def int16_vector_json_or_none(x: Optional[str]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.array(json_loads(x), dtype=numpy.int16) + np = get_numpy() + if np is not None: + return np.array(json_loads(x), dtype=np.int16) return map(int, json_loads(x)) @@ -815,8 +820,9 @@ def int16_vector_or_none(x: Optional[bytes]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.frombuffer(x, dtype=numpy.int16) + np = get_numpy() + if np is not None: + return np.frombuffer(x, dtype=np.int16) return struct.unpack(f'<{len(x)//2}h', x) @@ -843,8 +849,9 @@ def int32_vector_json_or_none(x: Optional[str]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.array(json_loads(x), dtype=numpy.int32) + np = get_numpy() + if np is not None: + return np.array(json_loads(x), dtype=np.int32) return map(int, json_loads(x)) @@ -871,8 +878,9 @@ def int32_vector_or_none(x: Optional[bytes]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.frombuffer(x, dtype=numpy.int32) + np = get_numpy() + if np is not None: + return np.frombuffer(x, dtype=np.int32) return struct.unpack(f'<{len(x)//4}l', x) @@ -899,8 +907,9 @@ def int64_vector_json_or_none(x: Optional[str]) -> Optional[Any]: if x is None: return None - if has_numpy: - return numpy.array(json_loads(x), dtype=numpy.int64) + np = get_numpy() + if np is not None: + return np.array(json_loads(x), dtype=np.int64) return map(int, json_loads(x)) @@ -928,8 +937,9 @@ def int64_vector_or_none(x: Optional[bytes]) -> Optional[Any]: return None # Bytes - if has_numpy: - return numpy.frombuffer(x, dtype=numpy.int64) + np = get_numpy() + if np is not None: + return np.frombuffer(x, dtype=np.int64) return struct.unpack(f'<{len(x)//8}l', x) diff --git a/singlestoredb/functions/dtypes.py b/singlestoredb/functions/dtypes.py index 0fe26a45..80aabd61 100644 --- a/singlestoredb/functions/dtypes.py +++ b/singlestoredb/functions/dtypes.py @@ -11,10 +11,9 @@ from ..converters import converters from ..mysql.converters import escape_item # type: ignore from ..utils.dtypes import DEFAULT_VALUES # noqa -from ..utils.dtypes import NUMPY_TYPE_MAP # noqa -from ..utils.dtypes import PANDAS_TYPE_MAP # noqa -from ..utils.dtypes import POLARS_TYPE_MAP # noqa -from ..utils.dtypes import PYARROW_TYPE_MAP # noqa +from ..utils.dtypes import get_numpy_type_map # noqa +from ..utils.dtypes import get_polars_type_map # noqa +from ..utils.dtypes import get_pyarrow_type_map # noqa DataType = Union[str, Callable[..., Any]] diff --git a/singlestoredb/functions/ext/collocated/__init__.py b/singlestoredb/functions/ext/collocated/__init__.py new file mode 100644 index 00000000..4b340031 --- /dev/null +++ b/singlestoredb/functions/ext/collocated/__init__.py @@ -0,0 +1 @@ +"""High-performance collocated Python UDF server for SingleStoreDB.""" diff --git a/singlestoredb/functions/ext/collocated/__main__.py b/singlestoredb/functions/ext/collocated/__main__.py new file mode 100644 index 00000000..402050a4 --- /dev/null +++ b/singlestoredb/functions/ext/collocated/__main__.py @@ -0,0 +1,132 @@ +""" +CLI entry point for the collocated Python UDF server. + +Usage:: + + python -m singlestoredb.functions.ext.collocated \\ + --extension myfuncs \\ + --extension-path /home/user/libs \\ + --socket /tmp/my-udf.sock + +Arguments match the Rust wasm-udf-server CLI for drop-in compatibility. +""" +import argparse +import logging +import os +import secrets +import sys +import tempfile +from typing import Any + +from .registry import setup_logging +from .server import Server + +logger = logging.getLogger('collocated') + + +def main(argv: Any = None) -> None: + parser = argparse.ArgumentParser( + prog='python -m singlestoredb.functions.ext.collocated', + description='High-performance collocated Python UDF server', + ) + parser.add_argument( + '--extension', + default=os.environ.get('EXTERNAL_UDF_EXTENSION', ''), + help=( + 'Python module to import (e.g. myfuncs). ' + 'Env: EXTERNAL_UDF_EXTENSION' + ), + ) + parser.add_argument( + '--extension-path', + default=os.environ.get('EXTERNAL_UDF_EXTENSION_PATH', ''), + help=( + 'Colon-separated search dirs for the module. ' + 'Env: EXTERNAL_UDF_EXTENSION_PATH' + ), + ) + parser.add_argument( + '--socket', + default=os.environ.get( + 'EXTERNAL_UDF_SOCKET_PATH', + os.path.join( + tempfile.gettempdir(), + f'singlestore-udf-{os.getpid()}-{secrets.token_hex(4)}.sock', + ), + ), + help=( + 'Unix socket path. ' + 'Env: EXTERNAL_UDF_SOCKET_PATH' + ), + ) + parser.add_argument( + '--n-workers', + type=int, + default=int(os.environ.get('EXTERNAL_UDF_N_WORKERS', '0')), + help=( + 'Worker threads (0 = CPU count). ' + 'Env: EXTERNAL_UDF_N_WORKERS' + ), + ) + parser.add_argument( + '--max-connections', + type=int, + default=int(os.environ.get('EXTERNAL_UDF_MAX_CONNECTIONS', '32')), + help=( + 'Socket backlog. ' + 'Env: EXTERNAL_UDF_MAX_CONNECTIONS' + ), + ) + parser.add_argument( + '--log-level', + default=os.environ.get('EXTERNAL_UDF_LOG_LEVEL', 'info'), + choices=['debug', 'info', 'warning', 'error'], + help=( + 'Logging level. ' + 'Env: EXTERNAL_UDF_LOG_LEVEL' + ), + ) + parser.add_argument( + '--process-mode', + default=os.environ.get('EXTERNAL_UDF_PROCESS_MODE', 'process'), + choices=['thread', 'process'], + help=( + 'Concurrency mode: "thread" uses a thread pool, ' + '"process" uses pre-fork workers for true CPU ' + 'parallelism. Env: EXTERNAL_UDF_PROCESS_MODE' + ), + ) + + args = parser.parse_args(argv) + + if not args.extension: + parser.error( + '--extension is required ' + '(or set EXTERNAL_UDF_EXTENSION env var)', + ) + + # Setup logging + level = getattr(logging, args.log_level.upper()) + setup_logging(level) + + config = { + 'extension': args.extension, + 'extension_path': args.extension_path, + 'socket': args.socket, + 'n_workers': args.n_workers, + 'max_connections': args.max_connections, + 'process_mode': args.process_mode, + } + + server = Server(config) + try: + server.run() + except RuntimeError as exc: + logger.error(str(exc)) + sys.exit(1) + except KeyboardInterrupt: + pass + + +if __name__ == '__main__': + main() diff --git a/singlestoredb/functions/ext/collocated/connection.py b/singlestoredb/functions/ext/collocated/connection.py new file mode 100644 index 00000000..5d025640 --- /dev/null +++ b/singlestoredb/functions/ext/collocated/connection.py @@ -0,0 +1,424 @@ +""" +Connection handler: protocol, mmap I/O, request loop. + +Implements the binary socket protocol matching the Rust wasm-udf-server: +handshake, control signal dispatch, and UDF request loop with mmap I/O. +""" +from __future__ import annotations + +import array +import logging +import mmap +import os +import socket +import struct +import threading +import time +import traceback +from typing import TYPE_CHECKING + +from .control import dispatch_control_signal +from .registry import _has_accel +from .registry import _mmap_read +from .registry import _mmap_write +from .registry import _recv_exact as _c_recv_exact +from .registry import call_function + +if TYPE_CHECKING: + from .server import SharedRegistry + +logger = logging.getLogger('collocated.connection') + +# Protocol constants +PROTOCOL_VERSION = 1 +STATUS_OK = 200 +STATUS_BAD_REQUEST = 400 +STATUS_ERROR = 500 + +# Minimum output mmap size to avoid repeated ftruncate +_MIN_OUTPUT_SIZE = 128 * 1024 + +# Pre-pack the status OK header prefix to avoid per-request struct.pack +_STATUS_OK_PREFIX = struct.pack(' None: + """Handle a single client connection (runs in a thread pool worker).""" + try: + _handle_connection_inner( + conn, shared_registry, shutdown_event, pipe_write_fd, + ) + except Exception: + logger.error(f'Connection error:\n{traceback.format_exc()}') + finally: + try: + conn.close() + except OSError: + pass + + +def _handle_connection_inner( + conn: socket.socket, + shared_registry: SharedRegistry, + shutdown_event: threading.Event, + pipe_write_fd: int | None = None, +) -> None: + """Inner connection handler (may raise).""" + # --- Handshake --- + # Receive 16 bytes: [version: u64 LE][namelen: u64 LE] + header = _recv_exact_py(conn, 16) + if header is None: + return + version, namelen = struct.unpack(' _MAX_FUNCTION_NAME_LEN: + logger.warning(f'Function name too long: {namelen}') + return + + # Receive function name + 2 FDs via SCM_RIGHTS + fd_model = array.array('i', [0, 0]) + msg, ancdata, flags, addr = conn.recvmsg( + namelen, + socket.CMSG_LEN(2 * fd_model.itemsize), + ) + + # Validate ancdata and extract FDs + received_fds: list[int] = [] + try: + if len(ancdata) != 1: + logger.warning(f'Expected 1 ancdata, got {len(ancdata)}') + return + + level, type_, fd_data = ancdata[0] + if level != socket.SOL_SOCKET or type_ != socket.SCM_RIGHTS: + logger.warning( + f'Unexpected ancdata level={level} type={type_}', + ) + return + + if flags & getattr(socket, 'MSG_CTRUNC', 0): + logger.warning('Ancillary data was truncated (MSG_CTRUNC)') + return + + fd_array = array.array('i') + fd_array.frombytes(fd_data) + received_fds = list(fd_array) + + if len(received_fds) != 2: + logger.warning( + f'Expected 2 FDs, got {len(received_fds)}', + ) + return + + function_name = msg.decode('utf8') + input_fd, output_fd = received_fds[0], received_fds[1] + # Clear so finally doesn't close FDs we're handing off + received_fds = [] + finally: + # Close any received FDs if we're returning early + for fd in received_fds: + try: + os.close(fd) + except OSError: + pass + + # --- Control signal path --- + if function_name.startswith('@@'): + logger.info(f"Received control signal '{function_name}'") + _handle_control_signal( + conn, function_name, input_fd, output_fd, shared_registry, + pipe_write_fd, + ) + return + + # --- UDF request loop --- + logger.info(f"Received request for function '{function_name}'") + _handle_udf_loop( + conn, function_name, input_fd, output_fd, + shared_registry, shutdown_event, + ) + + +def _handle_control_signal( + conn: socket.socket, + signal_name: str, + input_fd: int, + output_fd: int, + shared_registry: SharedRegistry, + pipe_write_fd: int | None = None, +) -> None: + """Handle a @@-prefixed control signal (one-shot request-response).""" + try: + # Read 8-byte request length + len_buf = _recv_exact_py(conn, 8) + if len_buf is None: + return + length = struct.unpack(' 0: + if _has_accel: + request_data = _mmap_read(input_fd, length) + else: + mem = mmap.mmap( + input_fd, length, mmap.MAP_SHARED, mmap.PROT_READ, + ) + try: + request_data = bytes(mem[:length]) + finally: + mem.close() + + # Dispatch + result = dispatch_control_signal( + signal_name, request_data, shared_registry, pipe_write_fd, + ) + + if result.ok: + # Write response to output mmap + response_bytes = result.data.encode('utf8') + response_size = len(response_bytes) + if _has_accel: + _mmap_write( + output_fd, response_bytes, + max(_MIN_OUTPUT_SIZE, response_size), + ) + else: + os.ftruncate(output_fd, max(_MIN_OUTPUT_SIZE, response_size)) + os.lseek(output_fd, 0, os.SEEK_SET) + _write_all_fd(output_fd, response_bytes) + + # Send [status=200, size] + conn.sendall(struct.pack(' None: + """Handle the UDF request loop for a single function.""" + # Track output mmap size to avoid repeated ftruncate + current_output_size = 0 + + # Choose recv implementation: C accel or Python fallback + use_accel = _has_accel + sock_fd = conn.fileno() + + if use_accel: + # Keep the fd in blocking mode. The C recv_exact uses poll() + # internally with a timeout, avoiding the interaction between + # Python's settimeout() (which sets O_NONBLOCK on the fd) and + # direct fd-level recv() in the C code. + pass + else: + # Python fallback: settimeout makes recv_into raise + # socket.timeout (alias for TimeoutError) when no data arrives. + conn.settimeout(0.1) + + # Profiling accumulators + profile = _PROFILE + if profile: + n_requests = 0 + t_recv = 0.0 + t_mmap_read = 0.0 + t_call = 0.0 + t_mmap_write = 0.0 + t_send = 0.0 + + try: + # Get thread-local registry + registry = shared_registry.get_thread_local_registry() + + while not shutdown_event.is_set(): + # Read 8-byte request length (with timeout for shutdown checks) + try: + if use_accel: + if profile: + t0 = time.monotonic() + len_buf = _c_recv_exact(sock_fd, 8, 100) + if profile: + t_recv += time.monotonic() - t0 + else: + if profile: + t0 = time.monotonic() + len_buf = _recv_exact_py(conn, 8) + if profile: + t_recv += time.monotonic() - t0 + except TimeoutError: + continue + except OSError: + break + + if len_buf is None: + break + length = struct.unpack(' current_output_size: + _mmap_write(output_fd, output_data, needed) + current_output_size = needed + else: + _mmap_write(output_fd, output_data, 0) + else: + needed = max(_MIN_OUTPUT_SIZE, response_size) + if needed > current_output_size: + os.ftruncate(output_fd, needed) + current_output_size = needed + os.lseek(output_fd, 0, os.SEEK_SET) + _write_all_fd(output_fd, output_data) + if profile: + t_mmap_write += time.monotonic() - t0 + + # Send [status=200, size] + if profile: + t0 = time.monotonic() + conn.sendall( + _STATUS_OK_PREFIX + struct.pack(' 0: + t_total = ( + t_recv + t_mmap_read + t_call + t_mmap_write + t_send + ) / n_requests * 1e6 + logger.info( + f"PROFILE '{function_name}' " + f'n={n_requests} ' + f'recv={t_recv / n_requests * 1e6:.1f}us ' + f'mmap_read={t_mmap_read / n_requests * 1e6:.1f}us ' + f'call={t_call / n_requests * 1e6:.1f}us ' + f'mmap_write={t_mmap_write / n_requests * 1e6:.1f}us ' + f'send={t_send / n_requests * 1e6:.1f}us ' + f'total={t_total:.1f}us', + ) + + +def _recv_exact_py(sock: socket.socket, n: int) -> bytes | None: + """Receive exactly n bytes, or return None on EOF.""" + buf = bytearray(n) + view = memoryview(buf) + pos = 0 + while pos < n: + try: + nbytes = sock.recv_into(view[pos:]) + except TimeoutError: + if pos == 0: + raise + # Partial message already consumed — must finish it. + # Remove timeout to avoid protocol desync. + sock.settimeout(None) + continue + if nbytes == 0: + return None + pos += nbytes + return bytes(buf) + + +def _write_all_fd(fd: int, data: bytes) -> None: + """Write all bytes to a file descriptor, handling partial writes.""" + view = memoryview(data) + written = 0 + while written < len(data): + try: + n = os.write(fd, view[written:]) + except InterruptedError: + continue + if n == 0: + raise RuntimeError('short write to output fd') + written += n diff --git a/singlestoredb/functions/ext/collocated/control.py b/singlestoredb/functions/ext/collocated/control.py new file mode 100644 index 00000000..b2d5c59d --- /dev/null +++ b/singlestoredb/functions/ext/collocated/control.py @@ -0,0 +1,134 @@ +""" +Control signal dispatch for @@health, @@functions, @@register. + +Matches the Rust wasm-udf-server's dispatch_control_signal behavior. +""" +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from .registry import describe_functions_json + +if TYPE_CHECKING: + from .server import SharedRegistry + +logger = logging.getLogger('collocated.control') + + +@dataclass +class ControlResult: + """Result of a control signal dispatch.""" + ok: bool + data: str # JSON response on success, error message on failure + + +def dispatch_control_signal( + signal_name: str, + request_data: bytes, + shared_registry: SharedRegistry, + pipe_write_fd: int | None = None, +) -> ControlResult: + """Dispatch a control signal to the appropriate handler.""" + try: + if signal_name == '@@health': + return _handle_health() + elif signal_name == '@@functions': + return _handle_functions(shared_registry) + elif signal_name == '@@register': + return _handle_register( + request_data, shared_registry, pipe_write_fd, + ) + else: + return ControlResult( + ok=False, + data=f'Unknown control signal: {signal_name}', + ) + except Exception as e: + return ControlResult(ok=False, data=str(e)) + + +def _handle_health() -> ControlResult: + """Handle @@health: return status ok.""" + return ControlResult(ok=True, data='{"status":"ok"}') + + +def _handle_functions(shared_registry: SharedRegistry) -> ControlResult: + """Handle @@functions: return function descriptions.""" + registry = shared_registry.get_thread_local_registry() + json_str = describe_functions_json(registry) + return ControlResult(ok=True, data=f'{{"functions":{json_str}}}') + + +def _handle_register( + request_data: bytes, + shared_registry: SharedRegistry, + pipe_write_fd: int | None = None, +) -> ControlResult: + """Handle @@register: register a new function dynamically. + + If ``pipe_write_fd`` is not None (process mode), the registration + payload is written to the pipe so the main process can update its + own registry and re-fork all workers. + """ + if not request_data: + return ControlResult(ok=False, data='Missing registration payload') + + try: + body = json.loads(request_data) + except json.JSONDecodeError as e: + return ControlResult(ok=False, data=f'Invalid JSON: {e}') + + function_name = body.get('function_name') + if not function_name: + return ControlResult( + ok=False, data='Missing required field: function_name', + ) + + args = body.get('args') + if not isinstance(args, list): + return ControlResult( + ok=False, data='Missing required field: args (must be an array)', + ) + + returns = body.get('returns') + if not isinstance(returns, list): + return ControlResult( + ok=False, + data='Missing required field: returns (must be an array)', + ) + + func_body = body.get('body') + if not func_body: + return ControlResult( + ok=False, data='Missing required field: body', + ) + + replace = body.get('replace', False) + + # Build signature JSON matching describe-functions schema + signature = json.dumps({ + 'name': function_name, + 'args': args, + 'returns': returns, + }) + + try: + shared_registry.create_function(signature, func_body, replace) + except Exception as e: + return ControlResult(ok=False, data=str(e)) + + # Notify main process so it can re-fork workers with updated state + if pipe_write_fd is not None: + from .server import _write_pipe_message + payload = json.dumps({ + 'signature_json': signature, + 'code': func_body, + 'replace': replace, + }).encode() + _write_pipe_message(pipe_write_fd, payload) + + logger.info(f"@@register: added function '{function_name}'") + return ControlResult(ok=True, data='{"status":"ok"}') diff --git a/singlestoredb/functions/ext/collocated/registry.py b/singlestoredb/functions/ext/collocated/registry.py new file mode 100644 index 00000000..3dadcbf2 --- /dev/null +++ b/singlestoredb/functions/ext/collocated/registry.py @@ -0,0 +1,475 @@ +""" +Function registry for UDF discovery, registration, and invocation. + +This module contains the core FunctionRegistry class (moved from +wasm/udf_handler.py) plus standalone call_function() and +describe_functions_json() helpers. Both the WASM handler and the +collocated server use these directly. +""" +import inspect +import json +import logging +import os +import sys +import traceback +import types +from datetime import datetime +from datetime import timezone +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +from singlestoredb.functions.ext.rowdat_1 import dump as _dump_rowdat_1 +from singlestoredb.functions.ext.rowdat_1 import load as _load_rowdat_1 +from singlestoredb.functions.signature import get_signature +from singlestoredb.mysql.constants import FIELD_TYPE as ft + +_accel_error: Optional[str] = None +try: + from _singlestoredb_accel import call_function_accel as _call_function_accel + from _singlestoredb_accel import mmap_read as _mmap_read + from _singlestoredb_accel import mmap_write as _mmap_write + from _singlestoredb_accel import recv_exact as _recv_exact + _has_accel = True + logging.getLogger(__name__).info('_singlestoredb_accel loaded successfully') +except Exception as e: + _has_accel = False + _accel_error = str(e) + _mmap_read = None + _mmap_write = None + _recv_exact = None + logging.getLogger(__name__).warning( + '_singlestoredb_accel failed to load: %s', e, + ) + + +class _TracingFormatter(logging.Formatter): + """Match Rust tracing-subscriber's colored output format.""" + + _RESET = '\033[0m' + _DIM = '\033[2m' + _BOLD = '\033[1m' + _LEVEL_COLORS = { + 'DEBUG': '\033[34m', # blue + 'INFO': '\033[32m', # green + 'WARNING': '\033[33m', # yellow + 'ERROR': '\033[31m', # red + 'CRITICAL': '\033[31m', # red + } + + def formatTime( + self, + record: logging.LogRecord, + datefmt: Optional[str] = None, + ) -> str: + dt = datetime.fromtimestamp(record.created, tz=timezone.utc) + return dt.strftime('%Y-%m-%dT%H:%M:%S.') + f'{dt.microsecond:06d}Z' + + def format(self, record: logging.LogRecord) -> str: + ts = self.formatTime(record) + color = self._LEVEL_COLORS.get(record.levelname, '') + level = f'{color}{self._BOLD}{record.levelname:>5}{self._RESET}' + name = f'{self._DIM}{record.name}{self._RESET}' + msg = record.getMessage() + return f'{self._DIM}{ts}{self._RESET} {level} {name}: {msg}' + + +def setup_logging(level: int = logging.INFO) -> None: + """Configure root logging with the tracing formatter.""" + handler = logging.StreamHandler() + handler.setFormatter(_TracingFormatter()) + logging.basicConfig(level=level, handlers=[handler]) + + +# Map dtype strings to rowdat_1 type codes for wire serialization. +# rowdat_1 always uses 8-byte encoding for integers and doubles for floats, +# so all int types collapse to LONGLONG and all float types to DOUBLE. +# Uses negative values for unsigned ints / binary data. +rowdat_1_type_map: Dict[str, int] = { + 'bool': ft.LONGLONG, + 'int8': ft.LONGLONG, + 'int16': ft.LONGLONG, + 'int32': ft.LONGLONG, + 'int64': ft.LONGLONG, + 'uint8': -ft.LONGLONG, + 'uint16': -ft.LONGLONG, + 'uint32': -ft.LONGLONG, + 'uint64': -ft.LONGLONG, + 'float32': ft.DOUBLE, + 'float64': ft.DOUBLE, + 'str': ft.STRING, + 'bytes': -ft.STRING, +} + +# Map dtype strings to Python type annotation strings for code generation. +_dtype_to_python: Dict[str, str] = { + 'bool': 'bool', + 'int8': 'int', + 'int16': 'int', + 'int32': 'int', + 'int64': 'int', + 'int': 'int', + 'uint8': 'int', + 'uint16': 'int', + 'uint32': 'int', + 'uint64': 'int', + 'float32': 'float', + 'float64': 'float', + 'float': 'float', + 'str': 'str', + 'bytes': 'bytes', +} + +logger = logging.getLogger('udf_handler') + + +class FunctionRegistry: + """Registry of discovered UDF functions.""" + + def __init__(self) -> None: + self.functions: Dict[str, Dict[str, Any]] = {} + + def initialize(self) -> None: + """Initialize and discover UDF functions from loaded modules. + + Scans sys.modules for any module containing @udf-decorated + functions. No _exports.py is needed -- modules just need to be + imported before initialize() is called (componentize-py captures + them at build time). + """ + self._discover_udf_functions() + + @staticmethod + def _is_stdlib_or_infra(mod_name: str, mod_file: str) -> bool: + """Check if a module is stdlib or infrastructure (not user UDF code). + + Uses the module's __file__ path to detect stdlib modules + (under sys.prefix but not in site-packages) rather than + maintaining a hardcoded list of names. + """ + _infra = frozenset({ + 'udf_handler', + }) + if mod_name in _infra: + return True + + real_file = os.path.realpath(mod_file) + real_prefix = os.path.realpath(sys.prefix) + + if real_file.startswith(real_prefix + os.sep): + if 'site-packages' not in real_file: + return True + + return False + + def _discover_udf_functions(self) -> None: + """Discover @udf functions by scanning sys.modules. + + Uses a two-pass approach: first, identify candidate modules + that import FunctionHandler (the convention for UDF modules). + Then extract @udf-decorated functions from those modules. + Modules without a __file__ (built-in/frozen) and stdlib/ + infrastructure modules are skipped automatically. + """ + # Import here to avoid circular dependency at module level + from .wasm import FunctionHandler + + found_modules = [] + for mod_name, mod in list(sys.modules.items()): + if mod is None: + continue + if not isinstance(mod, types.ModuleType): + continue + mod_file = getattr(mod, '__file__', None) + if mod_file is None: + continue + + # Short-circuit: only scan modules that import + # FunctionHandler (the convention for UDF modules) + if not any( + obj is FunctionHandler + for obj in vars(mod).values() + ): + continue + + if self._is_stdlib_or_infra(mod_name, mod_file): + continue + + self._extract_functions(mod) + if any( + hasattr(obj, '_singlestoredb_attrs') + for _, obj in inspect.getmembers(mod) + if inspect.isfunction(obj) + ): + found_modules.append(mod_name) + + if found_modules: + logger.info( + f'Discovered UDF functions from modules: ' + f'{", ".join(sorted(found_modules))}', + ) + else: + logger.warning( + 'No modules with @udf functions found in sys.modules.', + ) + + def _extract_functions(self, module: Any) -> None: + """Extract @udf-decorated functions from a module.""" + for name, obj in inspect.getmembers(module): + if name.startswith('_'): + continue + + if not callable(obj): + continue + + if not inspect.isfunction(obj): + continue + + if not hasattr(obj, '_singlestoredb_attrs'): + continue + + try: + sig = get_signature(obj) + if sig and sig.get('args') is not None and sig.get('returns'): + self._register_function(obj, name, sig) + except (TypeError, ValueError): + pass + + def _build_json_descriptions( + self, + func_names: List[str], + ) -> List[Dict[str, Any]]: + """Build JSON-serializable descriptions for the given function names.""" + descriptions = [] + for func_name in func_names: + func_info = self.functions[func_name] + sig = func_info['signature'] + args = [] + for arg in sig['args']: + args.append({ + 'name': arg['name'], + 'dtype': arg['dtype'], + 'sql': arg['sql'], + }) + returns = [] + for ret in sig['returns']: + returns.append({ + 'name': ret.get('name') or None, + 'dtype': ret['dtype'], + 'sql': ret['sql'], + }) + descriptions.append({ + 'name': func_name, + 'args': args, + 'returns': returns, + 'args_data_format': sig.get('args_data_format') or 'scalar', + 'returns_data_format': ( + sig.get('returns_data_format') or 'scalar' + ), + 'function_type': sig.get('function_type') or 'udf', + 'doc': sig.get('doc'), + }) + return descriptions + + @staticmethod + def _python_type_annotation(dtype: str) -> str: + """Convert a dtype string to a Python type annotation.""" + nullable = dtype.endswith('?') + base = dtype.rstrip('?') + py_type = _dtype_to_python.get(base) + if py_type is None: + raise ValueError(f'Unsupported dtype: {dtype!r}') + if nullable: + return f'Optional[{py_type}]' + return py_type + + @staticmethod + def _build_python_code( + sig: Dict[str, Any], + body: str, + ) -> str: + """Build a complete @udf-decorated Python function from sig + body.""" + func_name = sig['name'] + args = sig.get('args', []) + returns = sig.get('returns', []) + + params = [] + for arg in args: + ann = FunctionRegistry._python_type_annotation(arg['dtype']) + params.append(f'{arg["name"]}: {ann}') + params_str = ', '.join(params) + + if len(returns) == 0: + ret_ann = 'None' + elif len(returns) == 1: + ret_ann = FunctionRegistry._python_type_annotation( + returns[0]['dtype'], + ) + else: + parts = [ + FunctionRegistry._python_type_annotation(r['dtype']) + for r in returns + ] + ret_ann = f'Tuple[{", ".join(parts)}]' + + indented_body = '\n'.join( + f' {line}' for line in body.splitlines() + ) + + return ( + 'from singlestoredb.functions import udf\n' + 'from typing import Optional, Tuple\n' + '\n' + '@udf\n' + f'def {func_name}({params_str}) -> {ret_ann}:\n' + f'{indented_body}\n' + ) + + def create_function( + self, + signature_json: str, + code: str, + replace: bool, + ) -> List[str]: + """Register a function from its signature and function body. + + Args: + signature_json: JSON object matching the describe-functions + element schema (must contain a 'name' field) + code: Function body (e.g. "return x * 3"), not full source + replace: If False, raise an error if the function already exists + + Returns: + List of newly registered function names + """ + sig = json.loads(signature_json) + func_name = sig.get('name') + if not func_name: + raise ValueError( + 'signature JSON must contain a "name" field', + ) + + if not replace and func_name in self.functions: + raise ValueError( + f'Function "{func_name}" already exists ' + f'(use replace=true to overwrite)', + ) + + if replace and func_name in self.functions: + del self.functions[func_name] + + full_code = self._build_python_code(sig, code) + + name = '__main__' + compiled = compile(full_code, f'<{name}>', 'exec') + + if name in sys.modules: + module = sys.modules[name] + else: + module = types.ModuleType(name) + module.__file__ = f'<{name}>' + sys.modules[name] = module + exec(compiled, module.__dict__) # noqa: S102 + + before_names = set(self.functions.keys()) + self._extract_functions(module) + new_names = [k for k in self.functions if k not in before_names] + + if not new_names: + raise ValueError( + f'Function "{func_name}" was not registered. ' + f'Check that the signature dtypes are supported.', + ) + + logger.info( + f'create_function({func_name}): registered ' + f'{len(new_names)} function(s): {", ".join(new_names)}', + ) + return new_names + + def _register_function( + self, + func: Callable[..., Any], + func_name: str, + sig: Dict[str, Any], + ) -> None: + """Register a function under its bare name.""" + full_name = sig.get('name') or func_name + + arg_types: List[Tuple[str, int]] = [] + for arg in sig['args']: + dtype = arg['dtype'].replace('?', '') + if dtype not in rowdat_1_type_map: + logger.warning( + f"Skipping {full_name}: unsupported arg dtype '{dtype}'", + ) + return + arg_types.append((arg['name'], rowdat_1_type_map[dtype])) + + return_types: List[int] = [] + for ret in sig['returns']: + dtype = ret['dtype'].replace('?', '') + if dtype not in rowdat_1_type_map: + logger.warning( + f'Skipping {full_name}: no type mapping for {dtype}', + ) + return + return_types.append(rowdat_1_type_map[dtype]) + + self.functions[full_name] = { + 'func': func, + 'arg_types': arg_types, + 'return_types': return_types, + 'signature': sig, + } + + +def call_function( + registry: FunctionRegistry, + name: str, + input_data: bytes, +) -> bytes: + """Call a registered UDF by name using the C accelerator or fallback. + + This is the hot-path function used by both the WASM handler and + the collocated server. + """ + if name not in registry.functions: + raise ValueError(f'unknown function: {name}') + + func_info = registry.functions[name] + func = func_info['func'] + arg_types = func_info['arg_types'] + return_types = func_info['return_types'] + + try: + if _has_accel: + return _call_function_accel( + colspec=arg_types, + returns=return_types, + data=input_data, + func=func, + ) + + row_ids, rows = _load_rowdat_1(arg_types, input_data) + results = [] + for row in rows: + result = func(*row) + if not isinstance(result, (tuple, list)): + result = [result] + results.append(list(result)) + return bytes(_dump_rowdat_1(return_types, row_ids, results)) + + except Exception as e: + tb = traceback.format_exc() + raise RuntimeError(f'Error calling {name}: {e}\n{tb}') + + +def describe_functions_json(registry: FunctionRegistry) -> str: + """Serialize all function descriptions as a JSON array string.""" + func_names = list(registry.functions.keys()) + descriptions = registry._build_json_descriptions(func_names) + return json.dumps(descriptions) diff --git a/singlestoredb/functions/ext/collocated/server.py b/singlestoredb/functions/ext/collocated/server.py new file mode 100644 index 00000000..60e31d17 --- /dev/null +++ b/singlestoredb/functions/ext/collocated/server.py @@ -0,0 +1,499 @@ +""" +Server lifecycle: accept loop, thread pool, shutdown. + +Mirrors the Rust wasm-udf-server architecture with a ThreadPoolExecutor +for concurrent request handling and a SharedRegistry with generation- +counter caching for thread-safe live reload. +""" +import importlib +import json +import logging +import multiprocessing +import os +import select +import signal +import socket +import struct +import sys +import threading +import traceback +from concurrent.futures import ThreadPoolExecutor +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +from .connection import _write_all_fd +from .connection import handle_connection +from .registry import FunctionRegistry + +logger = logging.getLogger('collocated.server') + + +def _read_pipe_message(fd: int) -> Optional[bytes]: + """Read a length-prefixed message from a pipe fd. + + Wire format: [u32 LE length][payload]. + Returns None on EOF or short read. + """ + try: + len_buf = b'' + while len(len_buf) < 4: + chunk = os.read(fd, 4 - len(len_buf)) + if not chunk: + return None + len_buf += chunk + length = struct.unpack(' None: + """Write a length-prefixed message to a pipe fd. + + Wire format: [u32 LE length][payload]. + """ + header = struct.pack(' None: + self._lock = threading.Lock() + self._generation: int = 0 + self._code_blocks: List[Tuple[str, str, bool]] = [] + self._base_registry: Optional[FunctionRegistry] = None + self._local = threading.local() + + def set_base_registry(self, registry: FunctionRegistry) -> None: + """Set the base registry (after initial module import + init).""" + with self._lock: + self._base_registry = registry + + @property + def generation(self) -> int: + return self._generation + + def create_function( + self, + signature_json: str, + code: str, + replace: bool, + ) -> List[str]: + """Register a new function and bump the generation counter. + + Thread-safe: acquires the lock, validates via a temporary + registry, stores the code block, and increments generation. + """ + with self._lock: + # Validate on a temporary registry first + test_registry = self._build_fresh_registry() + new_names = test_registry.create_function( + signature_json, code, replace, + ) + # Success: store the code block and bump generation + self._code_blocks.append((signature_json, code, replace)) + self._generation += 1 + logger.info( + f'SharedRegistry: generation={self._generation}, ' + f'code_blocks={len(self._code_blocks)}', + ) + return new_names + + def get_thread_local_registry(self) -> FunctionRegistry: + """Get or refresh the thread-local cached registry. + + Cheap int comparison on the hot path; only rebuilds on + generation mismatch. + """ + cached = getattr(self._local, 'cached', None) + if cached is not None: + cached_gen, cached_reg = cached + if cached_gen == self._generation: + return cached_reg + + # Rebuild from base + code blocks + with self._lock: + registry = self._build_fresh_registry() + gen = self._generation + + self._local.cached = (gen, registry) + return registry + + def _build_fresh_registry(self) -> FunctionRegistry: + """Build a fresh registry with base functions + all code blocks. + + Must be called with self._lock held. + """ + registry = FunctionRegistry() + # Copy base functions + if self._base_registry is not None: + registry.functions = dict(self._base_registry.functions) + # Replay code blocks + for sig_json, code, replace in self._code_blocks: + registry.create_function(sig_json, code, replace) + return registry + + +class Server: + """Collocated UDF server with Unix socket + thread pool.""" + + def __init__(self, config: Dict[str, Any]) -> None: + self.config = config + self.shared_registry = SharedRegistry() + self.shutdown_event = threading.Event() + + def run(self) -> None: + """Run the server: import modules, bind socket, accept loop.""" + # 1. Import user modules & initialize registry + registry = self._initialize_registry() + self.shared_registry.set_base_registry(registry) + + # 2. Create & bind Unix socket + server_sock = self._bind_socket() + + # 3. Determine worker count and process mode + n_workers = self.config.get('n_workers', 0) + if n_workers <= 0: + n_workers = os.cpu_count() or 4 + + process_mode = self.config.get('process_mode', 'process') + + # 4. Signal handling (main process) + def _signal_handler(signum: int, frame: Any) -> None: + logger.info(f'Received signal {signum}, shutting down...') + self.shutdown_event.set() + + signal.signal(signal.SIGINT, _signal_handler) + signal.signal(signal.SIGTERM, _signal_handler) + + # 5. Dispatch to mode-specific loop + sock_path = self.config['socket'] + try: + if process_mode == 'process': + self._run_process_mode(server_sock, n_workers) + else: + self._run_thread_mode(server_sock, n_workers) + finally: + server_sock.close() + try: + os.unlink(sock_path) + except OSError: + pass + logger.info('Server stopped.') + + def _bind_socket(self) -> socket.socket: + """Create, bind, and listen on the Unix domain socket.""" + sock_path = self.config['socket'] + if os.path.exists(sock_path): + os.unlink(sock_path) + + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.bind(sock_path) + os.chmod(sock_path, 0o600) + + backlog = self.config.get('max_connections', 32) + server_sock.listen(backlog) + logger.info(f'Listening on {sock_path} (backlog={backlog})') + return server_sock + + def _run_thread_mode( + self, + server_sock: socket.socket, + n_workers: int, + ) -> None: + """Accept loop using a ThreadPoolExecutor.""" + pool = ThreadPoolExecutor(max_workers=n_workers) + logger.info(f'Thread pool: {n_workers} workers') + + try: + while not self.shutdown_event.is_set(): + readable, _, _ = select.select( + [server_sock], [], [], 0.1, + ) + if not readable: + continue + + conn, _ = server_sock.accept() + pool.submit( + handle_connection, + conn, + self.shared_registry, + self.shutdown_event, + ) + finally: + logger.info('Shutting down thread pool...') + pool.shutdown(wait=True) + + def _run_process_mode( + self, + server_sock: socket.socket, + n_workers: int, + ) -> None: + """Pre-fork worker pool for true CPU parallelism. + + Each worker gets a pipe back to the main process. When a worker + receives @@register, it writes the registration payload to its + pipe. The main process reads it, updates its own registry, then + kills and re-forks all workers so every worker has the updated + registry state. + """ + try: + ctx = multiprocessing.get_context('fork') + except ValueError: + raise RuntimeError( + "Process mode requires 'fork' multiprocessing context, " + 'which is not available on this platform. ' + "Use process_mode='thread' instead.", + ) + # workers[wid] = (process, pipe_read_fd) + workers: Dict[ + int, + Tuple[multiprocessing.process.BaseProcess, int], + ] = {} + + def _spawn_worker(worker_id: int) -> Tuple[ + multiprocessing.process.BaseProcess, int, + ]: + pipe_r, pipe_w = os.pipe() + p = ctx.Process( + target=self._worker_process_main, + args=(server_sock, worker_id, pipe_w), + daemon=True, + ) + p.start() + # Close the write end in the parent — only the child writes + os.close(pipe_w) + logger.info( + f'Started worker {worker_id} (pid={p.pid})', + ) + return p, pipe_r + + def _kill_all_workers() -> None: + """SIGTERM all workers, wait, then SIGKILL stragglers.""" + for wid, (proc, pipe_r) in workers.items(): + if proc.is_alive(): + assert proc.pid is not None + os.kill(proc.pid, signal.SIGTERM) + for wid, (proc, pipe_r) in workers.items(): + proc.join(timeout=5.0) + if proc.is_alive(): + logger.warning( + f'Worker {wid} (pid={proc.pid}) ' + f'did not exit, terminating...', + ) + proc.terminate() + proc.join(timeout=2.0) + # Close all pipe read fds + for wid, (proc, pipe_r) in workers.items(): + try: + os.close(pipe_r) + except OSError: + pass + + def _respawn_all_workers() -> None: + """Kill all workers and re-fork them with fresh state.""" + _kill_all_workers() + workers.clear() + for i in range(n_workers): + workers[i] = _spawn_worker(i) + + # Fork initial workers + logger.info( + f'Process pool: spawning {n_workers} workers', + ) + for i in range(n_workers): + workers[i] = _spawn_worker(i) + + # Monitor loop using poll() over pipe read fds + try: + while not self.shutdown_event.is_set(): + poller = select.poll() + fd_to_wid: Dict[int, int] = {} + for wid, (proc, pipe_r) in workers.items(): + poller.register( + pipe_r, select.POLLIN | select.POLLHUP, + ) + fd_to_wid[pipe_r] = wid + + events = poller.poll(500) # 500ms timeout + + registration_received = False + for fd, event in events: + if fd not in fd_to_wid: + continue + wid = fd_to_wid[fd] + + if event & select.POLLIN: + msg = _read_pipe_message(fd) + if msg is not None: + # Apply registration to main's registry + try: + body = json.loads(msg) + self.shared_registry.create_function( + body['signature_json'], + body['code'], + body['replace'], + ) + logger.info( + 'Main process: applied ' + '@@register from worker ' + f'{wid}, will re-fork all ' + 'workers', + ) + registration_received = True + except Exception: + logger.error( + 'Main process: failed to ' + 'apply @@register:\n' + f'{traceback.format_exc()}', + ) + elif event & select.POLLHUP: + # Worker died — will be respawned below + pass + + if registration_received: + _respawn_all_workers() + continue + + # Check for dead workers and respawn individually + for wid, (proc, pipe_r) in list(workers.items()): + if not proc.is_alive(): + exitcode = proc.exitcode + if not self.shutdown_event.is_set(): + logger.warning( + f'Worker {wid} (pid={proc.pid}) ' + f'exited with code {exitcode}, ' + f'restarting...', + ) + try: + os.close(pipe_r) + except OSError: + pass + workers[wid] = _spawn_worker(wid) + finally: + logger.info('Shutting down worker processes...') + _kill_all_workers() + + def _worker_process_main( + self, + server_sock: socket.socket, + worker_id: int, + pipe_w: int, + ) -> None: + """Entry point for each forked worker process. + + Uses ``self.shared_registry`` inherited via fork (contains the + main process's current state). ``pipe_w`` is used to notify the + main process when @@register is handled so it can re-fork all + workers. + """ + try: + local_shutdown = threading.Event() + + def _worker_signal_handler( + signum: int, + frame: Any, + ) -> None: + local_shutdown.set() + + signal.signal(signal.SIGTERM, _worker_signal_handler) + signal.signal(signal.SIGINT, signal.SIG_IGN) + + # WARNING: setblocking(False) sets O_NONBLOCK on the open + # file description, which is shared across all forked + # processes. This is intentional here — all workers need + # non-blocking accept() to handle the thundering-herd race, + # and the parent process never calls accept() on this + # socket. Do NOT add blocking operations on this socket + # in the parent process after workers are forked. + server_sock.setblocking(False) + + registry = self.shared_registry.get_thread_local_registry() + logger.info( + f'Worker {worker_id} (pid={os.getpid()}) ready, ' + f'{len(registry.functions)} function(s)', + ) + + # Accept loop + while not local_shutdown.is_set(): + readable, _, _ = select.select( + [server_sock], [], [], 0.1, + ) + if not readable: + continue + + try: + conn, _ = server_sock.accept() + except BlockingIOError: + # Another worker won the accept race + continue + except OSError: + if local_shutdown.is_set(): + break + raise + + handle_connection( + conn, + self.shared_registry, + local_shutdown, + pipe_write_fd=pipe_w, + ) + except Exception: + logger.error( + f'Worker {worker_id} crashed:\n' + f'{traceback.format_exc()}', + ) + raise + finally: + try: + os.close(pipe_w) + except OSError: + pass + + def _initialize_registry(self) -> FunctionRegistry: + """Import the extension module and discover @udf functions.""" + extension = self.config['extension'] + extension_path = self.config.get('extension_path', '') + + # Prepend extension path directories to sys.path + if extension_path: + for p in reversed(extension_path.split(':')): + p = p.strip() + if p and p not in sys.path: + sys.path.insert(0, p) + logger.info(f'Added to sys.path: {p}') + + # Import the extension module + logger.info(f'Importing extension module: {extension}') + importlib.import_module(extension) + + # Initialize registry (discovers @udf functions from sys.modules) + registry = FunctionRegistry() + registry.initialize() + + func_count = len(registry.functions) + if func_count == 0: + raise RuntimeError( + f'No @udf functions found after importing {extension!r}', + ) + logger.info(f'Discovered {func_count} function(s)') + for name in sorted(registry.functions): + logger.info(f' function: {name}') + + return registry diff --git a/singlestoredb/functions/ext/collocated/wasm.py b/singlestoredb/functions/ext/collocated/wasm.py new file mode 100644 index 00000000..e528e78f --- /dev/null +++ b/singlestoredb/functions/ext/collocated/wasm.py @@ -0,0 +1,72 @@ +""" +Thin WIT adapter over FunctionRegistry. + +This module provides the FunctionHandler class that implements the +singlestore:udf/function-handler WIT interface by delegating to the +shared FunctionRegistry in registry.py. +""" +import logging +import traceback + +from .registry import _accel_error +from .registry import _has_accel +from .registry import call_function +from .registry import describe_functions_json +from .registry import FunctionRegistry +from .registry import setup_logging + +logger = logging.getLogger('udf_handler') + +# Global registry instance (used by WASM component runtime) +_registry = FunctionRegistry() + + +class FunctionHandler: + """Implementation of the singlestore:udf/function-handler interface.""" + + def initialize(self) -> None: + """Initialize and discover UDF functions from loaded modules.""" + setup_logging() + if _has_accel: + logger.info('Using accelerated C call_function_accel loop') + else: + logger.info('Using pure Python call_function loop') + if _accel_error: + logger.warning( + '_singlestoredb_accel failed to load: %s', + _accel_error, + ) + _registry.initialize() + + def call_function(self, name: str, input_data: bytes) -> bytes: + """Call a function by its registered name.""" + return call_function(_registry, name, input_data) + + def describe_functions(self) -> str: + """Describe all functions as a JSON array. + + Returns a JSON string containing an array of function + description objects. + """ + try: + return describe_functions_json(_registry) + except Exception as e: + tb = traceback.format_exc() + raise RuntimeError(f'{e}\n{tb}') + + def create_function( + self, + signature: str, + code: str, + replace: bool, + ) -> None: + """Register a function from its signature and function body. + + The ``code`` parameter should contain the function body, not a + full ``def`` statement or ``@udf``-decorated source. + """ + try: + _registry.create_function(signature, code, replace) + except Exception as e: + tb = traceback.format_exc() + raise RuntimeError(f'{e}\n{tb}') diff --git a/singlestoredb/functions/ext/json.py b/singlestoredb/functions/ext/json.py index 05710247..619c3ad7 100644 --- a/singlestoredb/functions/ext/json.py +++ b/singlestoredb/functions/ext/json.py @@ -7,10 +7,9 @@ from typing import TYPE_CHECKING from ..dtypes import DEFAULT_VALUES -from ..dtypes import NUMPY_TYPE_MAP -from ..dtypes import PANDAS_TYPE_MAP -from ..dtypes import POLARS_TYPE_MAP -from ..dtypes import PYARROW_TYPE_MAP +from ..dtypes import get_numpy_type_map +from ..dtypes import get_polars_type_map +from ..dtypes import get_pyarrow_type_map from ..dtypes import PYTHON_CONVERTERS if TYPE_CHECKING: @@ -140,7 +139,7 @@ def load_pandas( ( pd.Series( data, index=index, name=spec[0], - dtype=PANDAS_TYPE_MAP[spec[1]], + dtype=get_numpy_type_map()[spec[1]], ), pd.Series(mask, index=index, dtype=np.longlong), ) @@ -172,7 +171,7 @@ def load_polars( return pl.Series(None, row_ids, dtype=pl.Int64), \ [ ( - pl.Series(spec[0], data, dtype=POLARS_TYPE_MAP[spec[1]]), + pl.Series(spec[0], data, dtype=get_polars_type_map()[spec[1]]), pl.Series(None, mask, dtype=pl.Boolean), ) for (data, mask), spec in zip(cols, colspec) @@ -203,7 +202,7 @@ def load_numpy( return np.asarray(row_ids, dtype=np.longlong), \ [ ( - np.asarray(data, dtype=NUMPY_TYPE_MAP[spec[1]]), # type: ignore + np.asarray(data, dtype=get_numpy_type_map()[spec[1]]), # type: ignore np.asarray(mask, dtype=np.bool_), # type: ignore ) for (data, mask), spec in zip(cols, colspec) @@ -235,7 +234,7 @@ def load_arrow( [ ( pa.array( - data, type=PYARROW_TYPE_MAP[dtype], + data, type=get_pyarrow_type_map()[dtype], mask=pa.array(mask, type=pa.bool_()), ), pa.array(mask, type=pa.bool_()), diff --git a/singlestoredb/functions/ext/rowdat_1.py b/singlestoredb/functions/ext/rowdat_1.py index 94e966b7..c406d1a6 100644 --- a/singlestoredb/functions/ext/rowdat_1.py +++ b/singlestoredb/functions/ext/rowdat_1.py @@ -12,10 +12,9 @@ from ...config import get_option from ...mysql.constants import FIELD_TYPE as ft from ..dtypes import DEFAULT_VALUES -from ..dtypes import NUMPY_TYPE_MAP -from ..dtypes import PANDAS_TYPE_MAP -from ..dtypes import POLARS_TYPE_MAP -from ..dtypes import PYARROW_TYPE_MAP +from ..dtypes import get_numpy_type_map +from ..dtypes import get_polars_type_map +from ..dtypes import get_pyarrow_type_map if TYPE_CHECKING: try: @@ -212,7 +211,7 @@ def _load_pandas( index = pd.Series(row_ids) return pd.Series(row_ids, dtype=np.int64), [ ( - pd.Series(data, index=index, name=name, dtype=PANDAS_TYPE_MAP[dtype]), + pd.Series(data, index=index, name=name, dtype=get_numpy_type_map()[dtype]), pd.Series(mask, index=index, dtype=np.bool_), ) for (data, mask), (name, dtype) in zip(cols, colspec) @@ -247,7 +246,7 @@ def _load_polars( return pl.Series(None, row_ids, dtype=pl.Int64), \ [ ( - pl.Series(name=name, values=data, dtype=POLARS_TYPE_MAP[dtype]), + pl.Series(name=name, values=data, dtype=get_polars_type_map()[dtype]), pl.Series(values=mask, dtype=pl.Boolean), ) for (data, mask), (name, dtype) in zip(cols, colspec) @@ -282,7 +281,7 @@ def _load_numpy( return np.asarray(row_ids, dtype=np.int64), \ [ ( - np.asarray(data, dtype=NUMPY_TYPE_MAP[dtype]), # type: ignore + np.asarray(data, dtype=get_numpy_type_map()[dtype]), # type: ignore np.asarray(mask, dtype=np.bool_), # type: ignore ) for (data, mask), (name, dtype) in zip(cols, colspec) @@ -318,7 +317,7 @@ def _load_arrow( [ ( pa.array( - data, type=PYARROW_TYPE_MAP[dtype], + data, type=get_pyarrow_type_map()[dtype], mask=pa.array(mask, type=pa.bool_()), ), pa.array(mask, type=pa.bool_()), @@ -565,7 +564,7 @@ def _load_pandas_accel( numpy_ids, numpy_cols = _singlestoredb_accel.load_rowdat_1_numpy(colspec, data) cols = [ ( - pd.Series(data, name=name, dtype=PANDAS_TYPE_MAP[dtype]), + pd.Series(data, name=name, dtype=get_numpy_type_map()[dtype]), pd.Series(mask, dtype=np.bool_), ) for (name, dtype), (data, mask) in zip(colspec, numpy_cols) @@ -610,7 +609,7 @@ def _load_polars_accel( pl.Series( name=name, values=data.tolist() if dtype in string_types or dtype in binary_types else data, - dtype=POLARS_TYPE_MAP[dtype], + dtype=get_polars_type_map()[dtype], ), pl.Series(values=mask, dtype=pl.Boolean), ) @@ -653,7 +652,7 @@ def _load_arrow_accel( numpy_ids, numpy_cols = _singlestoredb_accel.load_rowdat_1_numpy(colspec, data) cols = [ ( - pa.array(data, type=PYARROW_TYPE_MAP[dtype], mask=mask), + pa.array(data, type=get_pyarrow_type_map()[dtype], mask=mask), pa.array(mask, type=pa.bool_()), ) for (data, mask), (name, dtype) in zip(numpy_cols, colspec) diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index cf2b5d01..4ec89a48 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -332,7 +332,7 @@ def normalize_dtype(dtype: Any) -> str: if isinstance(dtype, str): return sql_to_dtype(dtype) - if typing.get_origin(dtype) is np.dtype: + if has_numpy and typing.get_origin(dtype) is np.dtype: dtype = typing.get_args(dtype)[0] # Specific types diff --git a/singlestoredb/management/manager.py b/singlestoredb/management/manager.py index 144dbb3e..575df087 100644 --- a/singlestoredb/management/manager.py +++ b/singlestoredb/management/manager.py @@ -10,7 +10,6 @@ from typing import Union from urllib.parse import urljoin -import jwt import requests from .. import config @@ -33,6 +32,7 @@ def set_organization(kwargs: Dict[str, Any]) -> None: def is_jwt(token: str) -> bool: """Is the given token a JWT?""" + import jwt try: jwt.decode(token, options={'verify_signature': False}) return True diff --git a/singlestoredb/management/utils.py b/singlestoredb/management/utils.py index 5aa072e4..ea0e04d2 100644 --- a/singlestoredb/management/utils.py +++ b/singlestoredb/management/utils.py @@ -18,8 +18,6 @@ from typing import Union from urllib.parse import urlparse -import jwt - from .. import converters from ..config import get_option from ..utils import events @@ -151,6 +149,7 @@ def handle_connection_info(msg: Dict[str, Any]) -> None: def retrieve_current_authentication_info() -> List[Tuple[str, Any]]: """Retrieve JWT if not expired.""" + import jwt nonlocal authentication_info password = authentication_info.get('password') if password: @@ -198,6 +197,7 @@ def get_authentication_info(include_env: bool = True) -> Dict[str, Any]: def get_token() -> Optional[str]: """Return the token for the Management API.""" + import jwt # See if an API key is configured tok = get_option('management.token') if tok: diff --git a/singlestoredb/mysql/connection.py b/singlestoredb/mysql/connection.py index 094fad68..2d4fdd2a 100644 --- a/singlestoredb/mysql/connection.py +++ b/singlestoredb/mysql/connection.py @@ -87,8 +87,9 @@ DEFAULT_USER = getpass.getuser() del getpass -except (ImportError, KeyError): +except (ImportError, KeyError, OSError): # KeyError occurs when there's no entry in OS database for a current user. + # OSError occurs in WASM environments where pwd module is unavailable. DEFAULT_USER = None DEBUG = get_option('debug.connection') diff --git a/singlestoredb/tests/test_connection.py b/singlestoredb/tests/test_connection.py index ee392d06..2ae5cf1d 100755 --- a/singlestoredb/tests/test_connection.py +++ b/singlestoredb/tests/test_connection.py @@ -22,8 +22,10 @@ try: import pandas as pd has_pandas = True + _pd_str_dtype = str(pd.DataFrame({'a': ['x']}).dtypes['a']) except ImportError: has_pandas = False + _pd_str_dtype = 'object' class TestConnection(unittest.TestCase): @@ -1124,21 +1126,21 @@ def test_alltypes_pandas(self): ('timestamp', 'datetime64[us]'), ('timestamp_6', 'datetime64[us]'), ('year', 'float64'), - ('char_100', 'object'), + ('char_100', _pd_str_dtype), ('binary_100', 'object'), - ('varchar_200', 'object'), + ('varchar_200', _pd_str_dtype), ('varbinary_200', 'object'), - ('longtext', 'object'), - ('mediumtext', 'object'), - ('text', 'object'), - ('tinytext', 'object'), + ('longtext', _pd_str_dtype), + ('mediumtext', _pd_str_dtype), + ('text', _pd_str_dtype), + ('tinytext', _pd_str_dtype), ('longblob', 'object'), ('mediumblob', 'object'), ('blob', 'object'), ('tinyblob', 'object'), ('json', 'object'), - ('enum', 'object'), - ('set', 'object'), + ('enum', _pd_str_dtype), + ('set', _pd_str_dtype), ('bit', 'object'), ] @@ -1266,21 +1268,21 @@ def test_alltypes_no_nulls_pandas(self): ('timestamp', 'datetime64[us]'), ('timestamp_6', 'datetime64[us]'), ('year', 'int16'), - ('char_100', 'object'), + ('char_100', _pd_str_dtype), ('binary_100', 'object'), - ('varchar_200', 'object'), + ('varchar_200', _pd_str_dtype), ('varbinary_200', 'object'), - ('longtext', 'object'), - ('mediumtext', 'object'), - ('text', 'object'), - ('tinytext', 'object'), + ('longtext', _pd_str_dtype), + ('mediumtext', _pd_str_dtype), + ('text', _pd_str_dtype), + ('tinytext', _pd_str_dtype), ('longblob', 'object'), ('mediumblob', 'object'), ('blob', 'object'), ('tinyblob', 'object'), ('json', 'object'), - ('enum', 'object'), - ('set', 'object'), + ('enum', _pd_str_dtype), + ('set', _pd_str_dtype), ('bit', 'object'), ] diff --git a/singlestoredb/utils/_lazy_import.py b/singlestoredb/utils/_lazy_import.py new file mode 100644 index 00000000..7bc53254 --- /dev/null +++ b/singlestoredb/utils/_lazy_import.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +"""Lazy import utilities for heavy optional dependencies.""" +import importlib +from functools import lru_cache +from typing import Any +from typing import Optional + + +@lru_cache(maxsize=None) +def get_numpy() -> Optional[Any]: + """Return numpy module or None if not installed.""" + try: + return importlib.import_module('numpy') + except ImportError: + return None + + +@lru_cache(maxsize=None) +def get_pandas() -> Optional[Any]: + """Return pandas module or None if not installed.""" + try: + return importlib.import_module('pandas') + except ImportError: + return None + + +@lru_cache(maxsize=None) +def get_polars() -> Optional[Any]: + """Return polars module or None if not installed.""" + try: + return importlib.import_module('polars') + except ImportError: + return None + + +@lru_cache(maxsize=None) +def get_pyarrow() -> Optional[Any]: + """Return pyarrow module or None if not installed.""" + try: + return importlib.import_module('pyarrow') + except ImportError: + return None diff --git a/singlestoredb/utils/dtypes.py b/singlestoredb/utils/dtypes.py index 73eb893c..e343110d 100644 --- a/singlestoredb/utils/dtypes.py +++ b/singlestoredb/utils/dtypes.py @@ -1,22 +1,11 @@ #!/usr/bin/env python3 +from functools import lru_cache +from typing import Any +from typing import Dict -try: - import numpy as np - has_numpy = True -except ImportError: - has_numpy = False - -try: - import polars as pl - has_polars = True -except ImportError: - has_polars = False - -try: - import pyarrow as pa - has_pyarrow = True -except ImportError: - has_pyarrow = False +from ._lazy_import import get_numpy +from ._lazy_import import get_polars +from ._lazy_import import get_pyarrow DEFAULT_VALUES = { @@ -64,8 +53,13 @@ } -if has_numpy: - NUMPY_TYPE_MAP = { +@lru_cache(maxsize=None) +def get_numpy_type_map() -> Dict[int, Any]: + """Return numpy type map, or empty dict if numpy is not installed.""" + np = get_numpy() + if np is None: + return {} + return { 0: object, # Decimal 1: np.int8, # Tiny -1: np.uint8, # Unsigned Tiny @@ -107,13 +101,15 @@ -254: object, # Binary 255: object, # Geometry } -else: - NUMPY_TYPE_MAP = {} -PANDAS_TYPE_MAP = NUMPY_TYPE_MAP -if has_pyarrow: - PYARROW_TYPE_MAP = { +@lru_cache(maxsize=None) +def get_pyarrow_type_map() -> Dict[int, Any]: + """Return pyarrow type map, or empty dict if pyarrow is not installed.""" + pa = get_pyarrow() + if pa is None: + return {} + return { 0: pa.decimal128(18, 6), # Decimal 1: pa.int8(), # Tiny -1: pa.uint8(), # Unsigned Tiny @@ -155,11 +151,15 @@ -254: pa.binary(), # Binary 255: pa.string(), # Geometry } -else: - PYARROW_TYPE_MAP = {} -if has_polars: - POLARS_TYPE_MAP = { + +@lru_cache(maxsize=None) +def get_polars_type_map() -> Dict[int, Any]: + """Return polars type map, or empty dict if polars is not installed.""" + pl = get_polars() + if pl is None: + return {} + return { 0: pl.Decimal(10, 6), # Decimal 1: pl.Int8, # Tiny -1: pl.UInt8, # Unsigned Tiny @@ -201,5 +201,3 @@ -254: pl.Binary, # Binary 255: pl.Utf8, # Geometry } -else: - POLARS_TYPE_MAP = {} diff --git a/singlestoredb/utils/events.py b/singlestoredb/utils/events.py index dab01f08..1a0c6644 100644 --- a/singlestoredb/utils/events.py +++ b/singlestoredb/utils/events.py @@ -7,7 +7,7 @@ try: from IPython import get_ipython has_ipython = True -except ImportError: +except (ImportError, OSError): has_ipython = False diff --git a/singlestoredb/utils/results.py b/singlestoredb/utils/results.py index 83846571..3264bd5f 100644 --- a/singlestoredb/utils/results.py +++ b/singlestoredb/utils/results.py @@ -2,6 +2,7 @@ """SingleStoreDB package utilities.""" import collections import warnings +from functools import lru_cache from typing import Any from typing import Callable from typing import Dict @@ -9,47 +10,34 @@ from typing import NamedTuple from typing import Optional from typing import Tuple +from typing import TYPE_CHECKING from typing import Union -from .dtypes import NUMPY_TYPE_MAP -from .dtypes import POLARS_TYPE_MAP -from .dtypes import PYARROW_TYPE_MAP +if TYPE_CHECKING: + import numpy + import pandas + import polars + import pyarrow + +from ._lazy_import import get_numpy +from ._lazy_import import get_pandas +from ._lazy_import import get_polars +from ._lazy_import import get_pyarrow +from .dtypes import get_numpy_type_map +from .dtypes import get_polars_type_map +from .dtypes import get_pyarrow_type_map UNSIGNED_FLAG = 32 BINARY_FLAG = 128 -try: - has_numpy = True - import numpy as np -except ImportError: - has_numpy = False - -try: - has_pandas = True - import pandas as pd -except ImportError: - has_pandas = False - -try: - has_polars = True - import polars as pl -except ImportError: - has_polars = False - -try: - has_pyarrow = True - import pyarrow as pa -except ImportError: - has_pyarrow = False - DBAPIResult = Union[List[Tuple[Any, ...]], Tuple[Any, ...]] OneResult = Union[ Tuple[Any, ...], Dict[str, Any], - 'np.ndarray', 'pd.DataFrame', 'pl.DataFrame', 'pa.Table', + 'numpy.ndarray', 'pandas.DataFrame', 'polars.DataFrame', 'pyarrow.Table', ] ManyResult = Union[ List[Tuple[Any, ...]], List[Dict[str, Any]], - 'np.ndarray', 'pd.DataFrame', 'pl.DataFrame', 'pa.Table', + 'numpy.ndarray', 'pandas.DataFrame', 'polars.DataFrame', 'pyarrow.Table', ] Result = Union[OneResult, ManyResult] @@ -67,11 +55,14 @@ class Description(NamedTuple): charset: Optional[int] -if has_numpy: - # If an int column is nullable, we need to use floats rather than - # ints for numpy and pandas. - NUMPY_TYPE_MAP_CAST_FLOAT = NUMPY_TYPE_MAP.copy() - NUMPY_TYPE_MAP_CAST_FLOAT.update({ +@lru_cache(maxsize=None) +def _get_numpy_type_map_cast_float() -> Dict[int, Any]: + """Return numpy type map with int types cast to float for nullable columns.""" + np = get_numpy() + if np is None: + return {} + type_map = get_numpy_type_map().copy() + type_map.update({ 1: np.float32, # Tiny -1: np.float32, # Unsigned Tiny 2: np.float32, # Short @@ -84,15 +75,23 @@ class Description(NamedTuple): -9: np.float64, # Unsigned Int24 13: np.float64, # Year }) + return type_map -if has_polars: + +@lru_cache(maxsize=None) +def _get_polars_type_map_with_dates() -> Dict[int, Any]: + """Return polars type map with date/times remapped to strings.""" + pl = get_polars() + if pl is None: + return {} + type_map = get_polars_type_map().copy() # Remap date/times to strings; let polars do the parsing - POLARS_TYPE_MAP = POLARS_TYPE_MAP.copy() - POLARS_TYPE_MAP.update({ + type_map.update({ 7: pl.Utf8, 10: pl.Utf8, 12: pl.Utf8, }) + return type_map INT_TYPES = set([1, 2, 3, 8, 9]) @@ -109,13 +108,15 @@ def signed(desc: Description) -> int: def _description_to_numpy_schema(desc: List[Description]) -> Dict[str, Any]: """Convert description to numpy array schema info.""" - if has_numpy: + if get_numpy() is not None: + numpy_type_map = get_numpy_type_map() + numpy_type_map_cast_float = _get_numpy_type_map_cast_float() return dict( dtype=[ ( x.name, - NUMPY_TYPE_MAP_CAST_FLOAT[signed(x)] - if x.null_ok else NUMPY_TYPE_MAP[signed(x)], + numpy_type_map_cast_float[signed(x)] + if x.null_ok else numpy_type_map[signed(x)], ) for x in desc ], @@ -125,18 +126,21 @@ def _description_to_numpy_schema(desc: List[Description]) -> Dict[str, Any]: def _description_to_pandas_schema(desc: List[Description]) -> Dict[str, Any]: """Convert description to pandas DataFrame schema info.""" - if has_pandas: + if get_pandas() is not None: return dict(columns=[x.name for x in desc]) return {} -def _decimalize_polars(desc: Description) -> 'pl.Decimal': - return pl.Decimal(desc.precision or 10, desc.scale or 0) +def _decimalize_polars(desc: Description) -> Any: + pl = get_polars() + return pl.Decimal(desc.precision or 10, desc.scale or 0) # type: ignore[union-attr] def _description_to_polars_schema(desc: List[Description]) -> Dict[str, Any]: """Convert description to polars DataFrame schema info.""" - if has_polars: + pl = get_polars() + if pl is not None: + polars_type_map = _get_polars_type_map_with_dates() with_columns = {} for x in desc: if x.type_code in [7, 12]: @@ -156,7 +160,8 @@ def _description_to_polars_schema(desc: List[Description]) -> Dict[str, Any]: schema=[ ( x.name, _decimalize_polars(x) - if x.type_code in DECIMAL_TYPES else POLARS_TYPE_MAP[signed(x)], + if x.type_code in DECIMAL_TYPES + else polars_type_map[signed(x)], ) for x in desc ], @@ -166,18 +171,24 @@ def _description_to_polars_schema(desc: List[Description]) -> Dict[str, Any]: return {} -def _decimalize_arrow(desc: Description) -> 'pa.Decimal128': - return pa.decimal128(desc.precision or 10, desc.scale or 0) +def _decimalize_arrow(desc: Description) -> Any: + pa = get_pyarrow() + return pa.decimal128( # type: ignore[union-attr] + desc.precision or 10, desc.scale or 0, + ) def _description_to_arrow_schema(desc: List[Description]) -> Dict[str, Any]: """Convert description to Arrow Table schema info.""" - if has_pyarrow: + pa = get_pyarrow() + if pa is not None: + pyarrow_type_map = get_pyarrow_type_map() return dict( schema=pa.schema([ ( x.name, _decimalize_arrow(x) - if x.type_code in DECIMAL_TYPES else PYARROW_TYPE_MAP[signed(x)], + if x.type_code in DECIMAL_TYPES + else pyarrow_type_map[signed(x)], ) for x in desc ]), @@ -215,7 +226,8 @@ def results_to_numpy( """ if not res: return res - if has_numpy: + np = get_numpy() + if np is not None: schema = _description_to_numpy_schema(desc) if schema is None else schema if single: return np.array([res], **schema) @@ -257,7 +269,8 @@ def results_to_pandas( """ if not res: return res - if has_pandas: + pd = get_pandas() + if pd is not None: schema = _description_to_pandas_schema(desc) if schema is None else schema return pd.DataFrame(results_to_numpy(desc, res, single=single, schema=schema)) warnings.warn( @@ -297,7 +310,8 @@ def results_to_polars( """ if not res: return res - if has_polars: + pl = get_polars() + if pl is not None: schema = _description_to_polars_schema(desc) if schema is None else schema if single: out = pl.DataFrame([res], orient='row', **schema.get('schema', {})) @@ -344,7 +358,8 @@ def results_to_arrow( """ if not res: return res - if has_pyarrow: + pa = get_pyarrow() + if pa is not None: names = [x[0] for x in desc] schema = _description_to_arrow_schema(desc) if schema is None else schema if single: diff --git a/wit/udf.wit b/wit/udf.wit new file mode 100644 index 00000000..6362b5d6 --- /dev/null +++ b/wit/udf.wit @@ -0,0 +1,26 @@ +package singlestore:udf; + +interface function-handler { + /// Initialize the handler and discover pre-imported UDF modules. + /// Must be called before any other functions. + initialize: func() -> result<_, string>; + + /// Call a function by its registered name (e.g., "my_func") + /// Input/output are rowdat_1 binary format + call-function: func(name: string, input: list) -> result, string>; + + /// Describe all registered functions as a JSON array of objects. + /// Each object has: name, args [{name, dtype, sql}], returns [{name, dtype, sql}], + /// args_data_format, returns_data_format, function_type, doc + describe-functions: func() -> result; + + /// Register a function from its signature and source code. + /// `signature` is a JSON object matching the describe-functions element schema. + /// `code` is the function body (not a full `def` statement). + /// `replace` controls whether an existing function of the same name is overwritten. + create-function: func(signature: string, code: string, replace: bool) -> result<_, string>; +} + +world external-udf { + export function-handler; +}