Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 201 additions & 29 deletions src/csrc/umath/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ extern "C" {
#include "scalar.h"
#include "dtype.h"
#include "ops.hpp"
#include "constants.hpp"
#include "umath/matmul.h"
#include "umath/promoters.hpp"
#include "quadblas_interface.h"
Expand All @@ -32,25 +33,23 @@ quad_matmul_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[
QuadPrecDTypeObject *descr_in1 = (QuadPrecDTypeObject *)given_descrs[0];
QuadPrecDTypeObject *descr_in2 = (QuadPrecDTypeObject *)given_descrs[1];

// QBLAS only supports SLEEF backend
static const char *non_sleef_msg =
"QBLAS-accelerated quad ufuncs only support the SLEEF backend. "
"Please raise the issue at SwayamInSync/QBLAS for longdouble support";

if (descr_in1->backend != BACKEND_SLEEF || descr_in2->backend != BACKEND_SLEEF) {
PyErr_SetString(PyExc_NotImplementedError,
"QBLAS-accelerated matmul only supports SLEEF backend. "
"Please raise the issue at SwayamInSync/QBLAS for longdouble support");
PyErr_SetString(PyExc_NotImplementedError, non_sleef_msg);
return (NPY_CASTING)-1;
}

// Both inputs must use SLEEF backend
QuadBackendType target_backend = BACKEND_SLEEF;
NPY_CASTING casting = NPY_NO_CASTING;

// Set up input descriptors
for (int i = 0; i < 2; i++) {
Py_INCREF(given_descrs[i]);
loop_descrs[i] = given_descrs[i];
}

// Set up output descriptor
if (given_descrs[2] == NULL) {
loop_descrs[2] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
if (!loop_descrs[2]) {
Expand All @@ -60,9 +59,7 @@ quad_matmul_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[
else {
QuadPrecDTypeObject *descr_out = (QuadPrecDTypeObject *)given_descrs[2];
if (descr_out->backend != target_backend) {
PyErr_SetString(PyExc_NotImplementedError,
"QBLAS-accelerated matmul only supports SLEEF backend. "
"Please raise the issue at SwayamInSync/QBLAS for longdouble support");
PyErr_SetString(PyExc_NotImplementedError, non_sleef_msg);
return (NPY_CASTING)-1;
}
else {
Expand Down Expand Up @@ -126,8 +123,8 @@ quad_matmul_strided_loop_aligned(PyArrayMethod_Context *context, char *const dat
}

MatmulOperationType op_type = determine_operation_type(m, n, p);
Sleef_quad alpha = Sleef_cast_from_doubleq1(1.0);
Sleef_quad beta = Sleef_cast_from_doubleq1(0.0);
Sleef_quad alpha = QUAD_PRECISION_ONE;
Sleef_quad beta = QUAD_PRECISION_ZERO;

char *A = data[0];
char *B = data[1];
Expand Down Expand Up @@ -216,8 +213,8 @@ quad_matmul_strided_loop_unaligned(PyArrayMethod_Context *context, char *const d
}

MatmulOperationType op_type = determine_operation_type(m, n, p);
Sleef_quad alpha = Sleef_cast_from_doubleq1(1.0);
Sleef_quad beta = Sleef_cast_from_doubleq1(0.0);
Sleef_quad alpha = QUAD_PRECISION_ONE;
Sleef_quad beta = QUAD_PRECISION_ZERO;

char *A = data[0];
char *B = data[1];
Expand Down Expand Up @@ -315,6 +312,157 @@ quad_matmul_strided_loop_unaligned(PyArrayMethod_Context *context, char *const d
return 0;
}

// vecdot: signature (n),(n)->()

static int
quad_vecdot_strided_loop_aligned(PyArrayMethod_Context *context, char *const data[],
npy_intp const dimensions[], npy_intp const strides[],
NpyAuxData *auxdata)
{
npy_intp N = dimensions[0]; // outer (broadcast) loop length
npy_intp n = dimensions[1]; // core dim length

npy_intp x_outer_stride = strides[0];
npy_intp y_outer_stride = strides[1];
npy_intp out_outer_stride = strides[2];
npy_intp x_n_stride = strides[3];
npy_intp y_n_stride = strides[4];

char *x = data[0];
char *y = data[1];
char *out = data[2];

size_t incx = x_n_stride / sizeof(Sleef_quad);
size_t incy = y_n_stride / sizeof(Sleef_quad);

for (npy_intp i = 0; i < N; i++) {
Sleef_quad *x_ptr = (Sleef_quad *)x;
Sleef_quad *y_ptr = (Sleef_quad *)y;
Sleef_quad *out_ptr = (Sleef_quad *)out;

if (n == 0) {
*out_ptr = QUAD_PRECISION_ZERO;
}
else {
int result = qblas_dot(n, x_ptr, incx, y_ptr, incy, out_ptr);
if (result != 0) {
PyErr_SetString(PyExc_RuntimeError, "QBLAS vecdot operation failed");
return -1;
}
}

x += x_outer_stride;
y += y_outer_stride;
out += out_outer_stride;
}

return 0;
}

static int
quad_vecdot_strided_loop_unaligned(PyArrayMethod_Context *context, char *const data[],
npy_intp const dimensions[], npy_intp const strides[],
NpyAuxData *auxdata)
{
npy_intp N = dimensions[0];
npy_intp n = dimensions[1];

npy_intp x_outer_stride = strides[0];
npy_intp y_outer_stride = strides[1];
npy_intp out_outer_stride = strides[2];

char *x = data[0];
char *y = data[1];
char *out = data[2];

if (n == 0) {
Sleef_quad zero = QUAD_PRECISION_ZERO;
for (npy_intp i = 0; i < N; i++) {
memcpy(out + i * out_outer_stride, &zero, sizeof(Sleef_quad));
}
return 0;
}

Sleef_quad *temp_x = new Sleef_quad[n];
Sleef_quad *temp_y = new Sleef_quad[n];

int result = 0;
for (npy_intp i = 0; i < N; i++) {
memcpy(temp_x, x, n * sizeof(Sleef_quad));
memcpy(temp_y, y, n * sizeof(Sleef_quad));

Sleef_quad sum;
result = qblas_dot(n, temp_x, 1, temp_y, 1, &sum);
if (result != 0) break;

memcpy(out, &sum, sizeof(Sleef_quad));

x += x_outer_stride;
y += y_outer_stride;
out += out_outer_stride;
}

delete[] temp_x;
delete[] temp_y;

if (result != 0) {
PyErr_SetString(PyExc_RuntimeError, "QBLAS vecdot operation failed");
return -1;
}
return 0;
}

static int
naive_vecdot_strided_loop(PyArrayMethod_Context *context, char *const data[],
npy_intp const dimensions[], npy_intp const strides[],
NpyAuxData *auxdata)
{
npy_intp N = dimensions[0];
npy_intp n = dimensions[1];

npy_intp x_outer_stride = strides[0];
npy_intp y_outer_stride = strides[1];
npy_intp out_outer_stride = strides[2];
npy_intp x_n_stride = strides[3];
npy_intp y_n_stride = strides[4];

QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
QuadBackendType backend = descr->backend;

char *x = data[0];
char *y = data[1];
char *out = data[2];

for (npy_intp i = 0; i < N; i++) {
if (backend == BACKEND_SLEEF) {
Sleef_quad sum = QUAD_PRECISION_ZERO;
for (npy_intp k = 0; k < n; k++) {
Sleef_quad a_val, b_val;
memcpy(&a_val, x + k * x_n_stride, sizeof(Sleef_quad));
memcpy(&b_val, y + k * y_n_stride, sizeof(Sleef_quad));
sum = Sleef_fmaq1_u05(a_val, b_val, sum);
}
memcpy(out, &sum, sizeof(Sleef_quad));
}
else {
long double sum = 0.0L;
for (npy_intp k = 0; k < n; k++) {
long double a_val, b_val;
memcpy(&a_val, x + k * x_n_stride, sizeof(long double));
memcpy(&b_val, y + k * y_n_stride, sizeof(long double));
sum += a_val * b_val;
}
memcpy(out, &sum, sizeof(long double));
}

x += x_outer_stride;
y += y_outer_stride;
out += out_outer_stride;
}

return 0;
}

static int
naive_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
npy_intp const dimensions[], npy_intp const strides[],
Expand Down Expand Up @@ -349,7 +497,7 @@ naive_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
char *C_ij = C + i * C_row_stride + j * C_col_stride;

if (backend == BACKEND_SLEEF) {
Sleef_quad sum = Sleef_cast_from_doubleq1(0.0);
Sleef_quad sum = QUAD_PRECISION_ZERO;

for (npy_intp k = 0; k < n; k++) {
char *A_ik = A + i * A_row_stride + k * A_col_stride;
Expand Down Expand Up @@ -385,32 +533,26 @@ naive_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
return 0;
}

int
init_matmul_ops(PyObject *numpy)
static int
register_matmul_like_ufunc(PyObject *numpy, const char *ufunc_name, const char *spec_name,
PyArrayMethod_StridedLoop *aligned_loop,
PyArrayMethod_StridedLoop *unaligned_loop)
{
PyObject *ufunc = PyObject_GetAttrString(numpy, "matmul");
PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
if (ufunc == NULL) {
return -1;
}

PyArray_DTypeMeta *dtypes[3] = {&QuadPrecDType, &QuadPrecDType, &QuadPrecDType};

#ifndef DISABLE_QUADBLAS

PyType_Slot slots[] = {
{NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
Comment thread
ngoldbaum marked this conversation as resolved.
{NPY_METH_strided_loop, (void *)&quad_matmul_strided_loop_aligned},
{NPY_METH_unaligned_strided_loop, (void *)&quad_matmul_strided_loop_unaligned},
{NPY_METH_strided_loop, (void *)aligned_loop},
{NPY_METH_unaligned_strided_loop, (void *)unaligned_loop},
{0, NULL}};
#else
PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
{NPY_METH_strided_loop, (void *)&naive_matmul_strided_loop},
{NPY_METH_unaligned_strided_loop, (void *)&naive_matmul_strided_loop},
{0, NULL}};
#endif // DISABLE_QUADBLAS

PyArrayMethod_Spec Spec = {
.name = "quad_matmul_qblas",
.name = spec_name,
.nin = 2,
.nout = 1,
.casting = NPY_NO_CASTING,
Expand Down Expand Up @@ -460,5 +602,35 @@ init_matmul_ops(PyObject *numpy)
Py_DECREF(promoter_capsule);
Py_DECREF(ufunc);

return 0;
}

int
init_matmul_ops(PyObject *numpy)
{
#ifndef DISABLE_QUADBLAS
if (register_matmul_like_ufunc(numpy, "matmul", "quad_matmul_qblas",
&quad_matmul_strided_loop_aligned,
&quad_matmul_strided_loop_unaligned) < 0) {
return -1;
}
if (register_matmul_like_ufunc(numpy, "vecdot", "quad_vecdot_qblas",
&quad_vecdot_strided_loop_aligned,
&quad_vecdot_strided_loop_unaligned) < 0) {
return -1;
}
#else
if (register_matmul_like_ufunc(numpy, "matmul", "quad_matmul_naive",
&naive_matmul_strided_loop,
&naive_matmul_strided_loop) < 0) {
return -1;
}
if (register_matmul_like_ufunc(numpy, "vecdot", "quad_vecdot_naive",
&naive_vecdot_strided_loop,
&naive_vecdot_strided_loop) < 0) {
return -1;
}
#endif // DISABLE_QUADBLAS

return 0;
}
Loading
Loading