Skip to content

Commit a8a0025

Browse files
author
peng.li24
committed
feat: add scalar (single-value) overloads for all element-wise functions
Add numpy::sqrt(x), numpy::sin(x), numpy::exp(x), ... scalar overloads inside namespace numpy, coexisting with the existing array API (different argument count / types, no ambiguity). Every scalar overload delegates to the public array API via &x, &x, 1 — never to detail:: — so AVX-512 specialisations and future array-level optimisations are automatically inherited. Unary (T → T): sqrt abs exp log sin cos tan cbrt expm1 log1p log10 log2 arcsin arccos arctan round floor ceil degrees radians sign Binary (T, T → T): power(x,e) hypot(x,y) arctan2(y,x) maximum(a,b) minimum(a,b) Ternary (T, T, T → T): clip(x, lo, hi) Usage: double r = numpy::sqrt(2.0); float s = numpy::sin(x); double n = numpy::hypot(3.0, 4.0); double c = numpy::clip(val, lo, hi); Both bit-exact and std builds verified.
1 parent c2a260f commit a8a0025

1 file changed

Lines changed: 52 additions & 0 deletions

File tree

numpy/elementwise.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,58 @@ inline void truncate_to_float32(const double* src, double* dst, size_t n) {
358358
dst[i] = static_cast<double>(tmp); });
359359
}
360360

361+
// ============================================================================
362+
// Scalar (single-value) overloads
363+
//
364+
// Parallel to every array API but taking a single T and returning T.
365+
// Call site: double r = numpy::sqrt(x); float r = numpy::sin(x);
366+
//
367+
// Unary math — delegate to the public array API via (&x, &x, 1), never to detail:: directly:
368+
// sqrt abs exp log sin cos tan cbrt expm1 log1p
369+
// log10 log2 arcsin arccos arctan round floor ceil degrees radians sign
370+
//
371+
// Binary — two scalars in, one scalar out:
372+
// power(x,e) hypot(x,y) arctan2(y,x) maximum(a,b) minimum(a,b)
373+
//
374+
// Ternary: clip(x, lo, hi)
375+
// ============================================================================
376+
377+
// ── Unary — route through the array API (inherits AVX-512 specialisations) ─
378+
379+
// Scalar sqrt: one value in, one value out.
// Calls the three-argument sqrt overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T sqrt(T x)
{
    sqrt(&x, &x, 1);
    return x;
}
380+
// Scalar abs: one value in, one value out.
// Calls the three-argument abs overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T abs(T x)
{
    abs(&x, &x, 1);
    return x;
}
381+
// Scalar exp: one value in, one value out.
// Calls the three-argument exp overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T exp(T x)
{
    exp(&x, &x, 1);
    return x;
}
382+
// Scalar log: one value in, one value out.
// Calls the three-argument log overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T log(T x)
{
    log(&x, &x, 1);
    return x;
}
383+
// Scalar sin: one value in, one value out.
// Calls the three-argument sin overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T sin(T x)
{
    sin(&x, &x, 1);
    return x;
}
384+
// Scalar cos: one value in, one value out.
// Calls the three-argument cos overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T cos(T x)
{
    cos(&x, &x, 1);
    return x;
}
385+
// Scalar tan: one value in, one value out.
// Calls the three-argument tan overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T tan(T x)
{
    tan(&x, &x, 1);
    return x;
}
386+
// Scalar cbrt: one value in, one value out.
// Calls the three-argument cbrt overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T cbrt(T x)
{
    cbrt(&x, &x, 1);
    return x;
}
387+
// Scalar expm1: one value in, one value out.
// Calls the three-argument expm1 overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T expm1(T x)
{
    expm1(&x, &x, 1);
    return x;
}
388+
// Scalar log1p: one value in, one value out.
// Calls the three-argument log1p overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T log1p(T x)
{
    log1p(&x, &x, 1);
    return x;
}
389+
// Scalar log10: one value in, one value out.
// Calls the three-argument log10 overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T log10(T x)
{
    log10(&x, &x, 1);
    return x;
}
390+
// Scalar log2: one value in, one value out.
// Calls the three-argument log2 overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T log2(T x)
{
    log2(&x, &x, 1);
    return x;
}
391+
// Scalar arcsin: one value in, one value out.
// Calls the three-argument arcsin overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T arcsin(T x)
{
    arcsin(&x, &x, 1);
    return x;
}
392+
// Scalar arccos: one value in, one value out.
// Calls the three-argument arccos overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T arccos(T x)
{
    arccos(&x, &x, 1);
    return x;
}
393+
// Scalar arctan: one value in, one value out.
// Calls the three-argument arctan overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T arctan(T x)
{
    arctan(&x, &x, 1);
    return x;
}
394+
// Scalar round: one value in, one value out.
// Calls the three-argument round overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T round(T x)
{
    round(&x, &x, 1);
    return x;
}
395+
// Scalar floor: one value in, one value out.
// Calls the three-argument floor overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T floor(T x)
{
    floor(&x, &x, 1);
    return x;
}
396+
// Scalar ceil: one value in, one value out.
// Calls the three-argument ceil overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T ceil(T x)
{
    ceil(&x, &x, 1);
    return x;
}
397+
// Scalar degrees: one value in, one value out.
// Calls the three-argument degrees overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T degrees(T x)
{
    degrees(&x, &x, 1);
    return x;
}
398+
// Scalar radians: one value in, one value out.
// Calls the three-argument radians overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T radians(T x)
{
    radians(&x, &x, 1);
    return x;
}
399+
// Scalar sign: one value in, one value out.
// Calls the three-argument sign overload (the array API) with the by-value
// copy of x as both pointer arguments and a count of 1, then returns that copy.
template <typename T>
inline T sign(T x)
{
    sign(&x, &x, 1);
    return x;
}
400+
401+
// ── Binary ─────────────────────────────────────────────────────────────────
402+
403+
// Scalar power: base x and exponent e in, one value out.
// Calls the four-argument power overload (the array API) with the by-value
// copy of x as both pointer arguments, a count of 1 and e as the trailing
// argument, then returns the copy of x.
template <typename T>
inline T power(T x, T e)
{
    power(&x, &x, 1, e);
    return x;
}
404+
// Scalar hypot: two values in, one value out.
// Calls the four-argument hypot overload (the array API) with pointers to
// x, y and a local result slot, plus a count of 1, and returns the slot.
template <typename T>
inline T hypot(T x, T y)
{
    // Value-initialised so an indeterminate value can never be returned,
    // even if the array routine were to leave the slot unwritten.
    T r{};
    hypot(&x, &y, &r, 1);
    return r;
}
405+
// Scalar arctan2: two values in (y first, then x), one value out.
// Calls the four-argument arctan2 overload (the array API) with pointers to
// y, x and a local result slot, plus a count of 1, and returns the slot.
template <typename T>
inline T arctan2(T y, T x)
{
    // Value-initialised so an indeterminate value can never be returned,
    // even if the array routine were to leave the slot unwritten.
    T r{};
    arctan2(&y, &x, &r, 1);
    return r;
}
406+
// Scalar maximum: two values in, one value out.
// Calls the four-argument maximum overload (the array API) with pointers to
// a, b and a local result slot, plus a count of 1, and returns the slot.
template <typename T>
inline T maximum(T a, T b)
{
    // Value-initialised so an indeterminate value can never be returned,
    // even if the array routine were to leave the slot unwritten.
    T r{};
    maximum(&a, &b, &r, 1);
    return r;
}
407+
// Scalar minimum: two values in, one value out.
// Calls the four-argument minimum overload (the array API) with pointers to
// a, b and a local result slot, plus a count of 1, and returns the slot.
template <typename T>
inline T minimum(T a, T b)
{
    // Value-initialised so an indeterminate value can never be returned,
    // even if the array routine were to leave the slot unwritten.
    T r{};
    minimum(&a, &b, &r, 1);
    return r;
}
408+
409+
// ── Ternary ────────────────────────────────────────────────────────────────
410+
411+
// Scalar clip: value x and bounds lo, hi in, one value out.
// Calls the five-argument clip overload (the array API) with the by-value
// copy of x as both pointer arguments, a count of 1 and the two bounds as
// trailing arguments, then returns the copy of x.
template <typename T>
inline T clip(T x, T lo, T hi)
{
    clip(&x, &x, 1, lo, hi);
    return x;
}
412+
361413
// ============================================================================
362414
// AVX-512 wide-loop template specialisations (0 ULP, ~8-16x faster)
363415
// Must appear inside namespace numpy after all primary templates.

0 commit comments

Comments
 (0)