11// SVML/npy bridge — resolve functions from numpy's _multiarray_umath.so.
22//
3- // numpy uses SVML for ALL transcendental functions when available:
4- // f64: __svml_exp8, __svml_log8, __svml_sin8, __svml_cos8, __svml_tan8,
5- // __svml_asin8, __svml_acos8, __svml_atan8, __svml_log108, __svml_log28,
6- // __svml_exp28, npy_pow, npy_atan2
7- // f32: __svml_expf16, __svml_logf16, __svml_sinf16, __svml_cosf16, __svml_tanf16,
8- // __svml_asinf16, __svml_acosf16, __svml_atanf16, __svml_log10f16, __svml_log2f16,
9- // __svml_exp2f16, npy_pow, npy_atan2
10- //
3+ // numpy uses SVML for ALL transcendental functions when available.
114// Call bridge_init(path_to_multiarray_umath_so) before first use.
125
136#pragma once
@@ -42,73 +35,48 @@ inline void* resolve_svml(const char* name) {
4235 fprintf (stderr, " [SVML] resolve(%s) -> %p dlerror=%s\n " , name, ptr, dlerror ());
4336 return ptr;
4437 }
45- fprintf (stderr, " [SVML] resolve(%s) -> NO HANDLE (g_svml_handle is null) \n " , name);
38+ fprintf (stderr, " [SVML] resolve(%s) -> NO HANDLE\n " , name);
4639 return nullptr ;
4740}
4841
49- #ifdef __AVX512F__
50-
5142// ============================================================================
52- // float64 — SVML for most, npy_pow / npy_atan2 for pow/atan2
43+ // SVML wrapper macros — consolidate repetitive f64/f32 patterns
5344// ============================================================================
5445
55- inline double exp_f64 (double x) {
56- static auto fn = (__m512d (*)(__m512d))resolve_svml (" __svml_exp8" );
57- if (fn) return _mm512_cvtsd_f64 (fn (_mm512_set1_pd (x)));
58- return std::exp (x);
59- }
60- inline double log_f64 (double x) {
61- static auto fn = (__m512d (*)(__m512d))resolve_svml (" __svml_log8" );
62- if (fn) return _mm512_cvtsd_f64 (fn (_mm512_set1_pd (x)));
63- return std::log (x);
64- }
65- inline double sin_f64 (double x) {
66- static auto fn = (__m512d (*)(__m512d))resolve_svml (" __svml_sin8" );
67- if (fn) return _mm512_cvtsd_f64 (fn (_mm512_set1_pd (x)));
68- return std::sin (x);
69- }
70- inline double cos_f64 (double x) {
71- static auto fn = (__m512d (*)(__m512d))resolve_svml (" __svml_cos8" );
72- if (fn) return _mm512_cvtsd_f64 (fn (_mm512_set1_pd (x)));
73- return std::cos (x);
74- }
75- inline double tan_f64 (double x) {
76- static auto fn = (__m512d (*)(__m512d))resolve_svml (" __svml_tan8" );
77- if (fn) return _mm512_cvtsd_f64 (fn (_mm512_set1_pd (x)));
78- return std::tan (x);
79- }
80- inline double asin_f64 (double x) {
81- static auto fn = (__m512d (*)(__m512d))resolve_svml (" __svml_asin8" );
82- if (fn) return _mm512_cvtsd_f64 (fn (_mm512_set1_pd (x)));
83- return std::asin (x);
84- }
85- inline double acos_f64 (double x) {
86- static auto fn = (__m512d (*)(__m512d))resolve_svml (" __svml_acos8" );
87- if (fn) return _mm512_cvtsd_f64 (fn (_mm512_set1_pd (x)));
88- return std::acos (x);
89- }
90- inline double atan_f64 (double x) {
91- static auto fn = (__m512d (*)(__m512d))resolve_svml (" __svml_atan8" );
92- if (fn) return _mm512_cvtsd_f64 (fn (_mm512_set1_pd (x)));
93- return std::atan (x);
94- }
95- inline double log10_f64 (double x) {
96- static auto fn = (__m512d (*)(__m512d))resolve_svml (" __svml_log108" );
97- if (fn) return _mm512_cvtsd_f64 (fn (_mm512_set1_pd (x)));
98- return std::log10 (x);
99- }
100- inline double log2_f64 (double x) {
101- static auto fn = (__m512d (*)(__m512d))resolve_svml (" __svml_log28" );
102- if (fn) return _mm512_cvtsd_f64 (fn (_mm512_set1_pd (x)));
103- return std::log2 (x);
104- }
105- inline double exp2_f64 (double x) {
106- static auto fn = (__m512d (*)(__m512d))resolve_svml (" __svml_exp28" );
107- if (fn) return _mm512_cvtsd_f64 (fn (_mm512_set1_pd (x)));
108- return std::exp2 (x);
109- }
46+ #ifdef __AVX512F__
11047
111- // numpy uses npy_pow / npy_atan2 (scalar C functions, not SVML vector)
48+ // 1-arg f64: __svml_XXX8 → __m512d → extract scalar
49+ // 1-arg f32: __svml_XXXf16 → __m512 → extract scalar
50+ #define NUMPY_SVML_F64 (name, sym, fallback ) \
51+ inline double name (double x) { \
52+ static auto fn = (__m512d (*)(__m512d)) \
53+ resolve_svml (sym); \
54+ if (fn) return _mm512_cvtsd_f64 ( \
55+ fn (_mm512_set1_pd (x))); \
56+ return fallback (x); \
57+ }
58+ #define NUMPY_SVML_F32 (name, sym, fallback ) \
59+ inline float name (float x) { \
60+ static auto fn = (__m512 (*)(__m512)) \
61+ resolve_svml (sym); \
62+ if (fn) return _mm512_cvtss_f32 ( \
63+ fn (_mm512_set1_ps (x))); \
64+ return fallback (x); \
65+ }
66+
67+ NUMPY_SVML_F64 (exp_f64, " __svml_exp8" , std::exp)
68+ NUMPY_SVML_F64 (log_f64, " __svml_log8" , std::log)
69+ NUMPY_SVML_F64 (sin_f64, " __svml_sin8" , std::sin)
70+ NUMPY_SVML_F64 (cos_f64, " __svml_cos8" , std::cos)
71+ NUMPY_SVML_F64 (tan_f64, " __svml_tan8" , std::tan)
72+ NUMPY_SVML_F64 (asin_f64, " __svml_asin8" , std::asin)
73+ NUMPY_SVML_F64 (acos_f64, " __svml_acos8" , std::acos)
74+ NUMPY_SVML_F64 (atan_f64, " __svml_atan8" , std::atan)
75+ NUMPY_SVML_F64 (log10_f64," __svml_log108" ,std::log10)
76+ NUMPY_SVML_F64 (log2_f64, " __svml_log28" , std::log2)
77+ NUMPY_SVML_F64 (exp2_f64, " __svml_exp28" , std::exp2)
78+
79+ // pow/atan2 — 2-arg scalar (npy_pow / npy_atan2)
11280inline double pow_f64 (double x, double e) {
11381 static auto fn = (double (*)(double , double ))resolve_svml (" npy_pow" );
11482 if (fn) return fn (x, e);
@@ -120,131 +88,108 @@ inline double atan2_f64(double y, double x) {
12088 return std::atan2 (y, x);
12189}
12290
123- // ============================================================================
124- // float32 — SVML for ALL transcendental functions
125- // ============================================================================
126-
127- // exp/log/sin/cos — use numpy's own polynomial approximations for bit-exact results
91+ // f32: exp/log/sin/cos use numpy's own polynomial approximations
12892inline float exp_f32 (float x) { return npy_float_math::npy_expf (x); }
12993inline float log_f32 (float x) { return npy_float_math::npy_logf (x); }
13094inline float sin_f32 (float x) { return npy_float_math::npy_sinf (x); }
13195inline float cos_f32 (float x) { return npy_float_math::npy_cosf (x); }
132- inline float tan_f32 (float x) {
133- static auto fn = (__m512 (*)(__m512))resolve_svml (" __svml_tanf16" );
134- if (fn) return _mm512_cvtss_f32 (fn (_mm512_set1_ps (x)));
135- return std::tan (x);
136- }
137- inline float asin_f32 (float x) {
138- static auto fn = (__m512 (*)(__m512))resolve_svml (" __svml_asinf16" );
139- if (fn) return _mm512_cvtss_f32 (fn (_mm512_set1_ps (x)));
140- return std::asin (x);
141- }
142- inline float acos_f32 (float x) {
143- static auto fn = (__m512 (*)(__m512))resolve_svml (" __svml_acosf16" );
144- if (fn) return _mm512_cvtss_f32 (fn (_mm512_set1_ps (x)));
145- return std::acos (x);
146- }
147- inline float atan_f32 (float x) {
148- static auto fn = (__m512 (*)(__m512))resolve_svml (" __svml_atanf16" );
149- if (fn) return _mm512_cvtss_f32 (fn (_mm512_set1_ps (x)));
150- return std::atan (x);
151- }
152- inline float log10_f32 (float x) {
153- static auto fn = (__m512 (*)(__m512))resolve_svml (" __svml_log10f16" );
154- if (fn) return _mm512_cvtss_f32 (fn (_mm512_set1_ps (x)));
155- return std::log10 (x);
156- }
157- inline float log2_f32 (float x) {
158- static auto fn = (__m512 (*)(__m512))resolve_svml (" __svml_log2f16" );
159- if (fn) return _mm512_cvtss_f32 (fn (_mm512_set1_ps (x)));
160- return std::log2 (x);
161- }
162- inline float exp2_f32 (float x) {
163- static auto fn = (__m512 (*)(__m512))resolve_svml (" __svml_exp2f16" );
164- if (fn) return _mm512_cvtss_f32 (fn (_mm512_set1_ps (x)));
165- return std::exp2 (x);
166- }
167- inline float pow_f32 (float x, float e) { return std::pow (x, e); }
96+
97+ NUMPY_SVML_F32 (tan_f32, " __svml_tanf16" , std::tan)
98+ NUMPY_SVML_F32 (asin_f32, " __svml_asinf16" , std::asin)
99+ NUMPY_SVML_F32 (acos_f32, " __svml_acosf16" , std::acos)
100+ NUMPY_SVML_F32 (atan_f32, " __svml_atanf16" , std::atan)
101+ NUMPY_SVML_F32 (log10_f32," __svml_log10f16" ,std::log10)
102+ NUMPY_SVML_F32 (log2_f32, " __svml_log2f16" , std::log2)
103+ NUMPY_SVML_F32 (exp2_f32, " __svml_exp2f16" , std::exp2)
104+
105+ inline float pow_f32 (float x, float e) { return std::pow (x, e); }
168106inline float atan2_f32 (float y, float x) { return std::atan2 (y, x); }
169107
108+ #undef NUMPY_SVML_F64
109+ #undef NUMPY_SVML_F32
110+
170111#else // !__AVX512F__
171112
172- inline double exp_f64 (double x) { return std::exp (x); }
173- inline double log_f64 (double x) { return std::log (x); }
174- inline double sin_f64 (double x) { return std::sin (x); }
175- inline double cos_f64 (double x) { return std::cos (x); }
176- inline double tan_f64 (double x) { return std::tan (x); }
177- inline double asin_f64 (double x) { return std::asin (x); }
178- inline double acos_f64 (double x) { return std::acos (x); }
179- inline double atan_f64 (double x) { return std::atan (x); }
180- inline double log10_f64 (double x){ return std::log10 (x); }
181- inline double log2_f64 (double x) { return std::log2 (x); }
182- inline double exp2_f64 (double x) { return std::exp2 (x); }
183- inline double pow_f64 (double x, double e) { return std::pow (x, e); }
113+ // Non-AVX512: all 1-arg wrappers degrade to std or npy_math
114+ #define NUMPY_FB_F64 (name, fallback ) inline double name (double x) { return fallback (x); }
115+ #define NUMPY_FB_F32 (name, fallback ) inline float name (float x) { return fallback (x); }
116+
117+ NUMPY_FB_F64 (exp_f64, std::exp)
118+ NUMPY_FB_F64 (log_f64, std::log)
119+ NUMPY_FB_F64 (sin_f64, std::sin)
120+ NUMPY_FB_F64 (cos_f64, std::cos)
121+ NUMPY_FB_F64 (tan_f64, std::tan)
122+ NUMPY_FB_F64 (asin_f64, std::asin)
123+ NUMPY_FB_F64 (acos_f64, std::acos)
124+ NUMPY_FB_F64 (atan_f64, std::atan)
125+ NUMPY_FB_F64 (log10_f64,std::log10)
126+ NUMPY_FB_F64 (log2_f64, std::log2)
127+ NUMPY_FB_F64 (exp2_f64, std::exp2)
128+ inline double pow_f64 (double x, double e) { return std::pow (x, e); }
184129inline double atan2_f64 (double y, double x) { return std::atan2 (y, x); }
185- inline float exp_f32 (float x) { return npy_float_math::npy_expf (x); }
186- inline float log_f32 (float x) { return npy_float_math::npy_logf (x); }
187- inline float sin_f32 (float x) { return npy_float_math::npy_sinf (x); }
188- inline float cos_f32 (float x) { return npy_float_math::npy_cosf (x); }
189- inline float tan_f32 (float x) { return std::tan (x); }
190- inline float asin_f32 (float x) { return std::asin (x); }
191- inline float acos_f32 (float x) { return std::acos (x); }
192- inline float atan_f32 (float x) { return std::atan (x); }
193- inline float log10_f32 (float x) { return std::log10 (x); }
194- inline float log2_f32 (float x) { return std::log2 (x); }
195- inline float exp2_f32 (float x) { return std::exp2 (x); }
196- inline float pow_f32 (float x, float e) { return std::pow (x, e); }
197- inline float atan2_f32 (float y, float x) { return std::atan2 (y, x); }
130+
131+ NUMPY_FB_F32 (exp_f32, npy_float_math::npy_expf)
132+ NUMPY_FB_F32 (log_f32, npy_float_math::npy_logf)
133+ NUMPY_FB_F32 (sin_f32, npy_float_math::npy_sinf)
134+ NUMPY_FB_F32 (cos_f32, npy_float_math::npy_cosf)
135+ NUMPY_FB_F32 (tan_f32, std::tan)
136+ NUMPY_FB_F32 (asin_f32, std::asin)
137+ NUMPY_FB_F32 (acos_f32, std::acos)
138+ NUMPY_FB_F32 (atan_f32, std::atan)
139+ NUMPY_FB_F32 (log10_f32,std::log10)
140+ NUMPY_FB_F32 (log2_f32, std::log2)
141+ NUMPY_FB_F32 (exp2_f32, std::exp2)
142+ inline float pow_f32 (float x, float e) { return std::pow (x, e); }
143+ inline float atan2_f32 (float y, float x) { return std::atan2 (y, x); }
144+
145+ #undef NUMPY_FB_F64
146+ #undef NUMPY_FB_F32
198147
199148#endif // __AVX512F__
200149
201150// ============================================================================
202- // Template dispatchers
151+ // Template dispatchers — svml_impl<T> + free function templates
203152// ============================================================================
204153
205- template <typename T> struct svml_impl ;
206- template <> struct svml_impl <double > {
207- static double exp (double x) { return exp_f64 (x); }
208- static double log (double x) { return log_f64 (x); }
209- static double sin (double x) { return sin_f64 (x); }
210- static double cos (double x) { return cos_f64 (x); }
211- static double tan (double x) { return tan_f64 (x); }
212- static double asin (double x) { return asin_f64 (x); }
213- static double acos (double x) { return acos_f64 (x); }
214- static double atan (double x) { return atan_f64 (x); }
215- static double log10 (double x){ return log10_f64 (x); }
216- static double log2 (double x) { return log2_f64 (x); }
217- static double exp2 (double x) { return exp2_f64 (x); }
218- static double pow (double x, double e) { return pow_f64 (x, e); }
219- static double atan2 (double y, double x) { return atan2_f64 (y, x); }
220- };
221- template <> struct svml_impl <float > {
222- static float exp (float x) { return exp_f32 (x); }
223- static float log (float x) { return log_f32 (x); }
224- static float sin (float x) { return sin_f32 (x); }
225- static float cos (float x) { return cos_f32 (x); }
226- static float tan (float x) { return tan_f32 (x); }
227- static float asin (float x) { return asin_f32 (x); }
228- static float acos (float x) { return acos_f32 (x); }
229- static float atan (float x) { return atan_f32 (x); }
230- static float log10 (float x){ return log10_f32 (x); }
231- static float log2 (float x) { return log2_f32 (x); }
232- static float exp2 (float x) { return exp2_f32 (x); }
233- static float pow (float x, float e) { return pow_f32 (x, e); }
234- static float atan2 (float y, float x) { return atan2_f32 (y, x); }
154+ #define NUMPY_SVML_METHODS (T, suff ) \
155+ template <> struct svml_impl <T> { \
156+ static T exp (T x) { return exp_##suff (x); } \
157+ static T log (T x) { return log_##suff (x); } \
158+ static T sin (T x) { return sin_##suff (x); } \
159+ static T cos (T x) { return cos_##suff (x); } \
160+ static T tan (T x) { return tan_##suff (x); } \
161+ static T asin (T x) { return asin_##suff (x); } \
162+ static T acos (T x) { return acos_##suff (x); } \
163+ static T atan (T x) { return atan_##suff (x); } \
164+ static T log10 (T x){ return log10_##suff (x); } \
165+ static T log2 (T x) { return log2_##suff (x); } \
166+ static T exp2 (T x) { return exp2_##suff (x); } \
167+ static T pow (T x, T e) { return pow_##suff (x, e); } \
168+ static T atan2 (T y, T x) { return atan2_##suff (y, x); } \
235169};
236170
237- template <typename T> inline T exp (T x) { return svml_impl<T>::exp (x); }
238- template <typename T> inline T log (T x) { return svml_impl<T>::log (x); }
239- template <typename T> inline T sin (T x) { return svml_impl<T>::sin (x); }
240- template <typename T> inline T cos (T x) { return svml_impl<T>::cos (x); }
241- template <typename T> inline T tan (T x) { return svml_impl<T>::tan (x); }
242- template <typename T> inline T asin (T x) { return svml_impl<T>::asin (x); }
243- template <typename T> inline T acos (T x) { return svml_impl<T>::acos (x); }
244- template <typename T> inline T atan (T x) { return svml_impl<T>::atan (x); }
245- template <typename T> inline T log10 (T x) { return svml_impl<T>::log10 (x); }
246- template <typename T> inline T log2 (T x) { return svml_impl<T>::log2 (x); }
247- template <typename T> inline T exp2 (T x) { return svml_impl<T>::exp2 (x); }
171+ template <typename T> struct svml_impl ;
172+ NUMPY_SVML_METHODS (double , f64 )
173+ NUMPY_SVML_METHODS (float , f32 )
174+ #undef NUMPY_SVML_METHODS
175+
176+ // 1-arg dispatchers
177+ #define NUMPY_SVML_D1 (name ) \
178+ template <typename T> inline T name (T x) { return svml_impl<T>::name (x); }
179+ NUMPY_SVML_D1 (exp)
180+ NUMPY_SVML_D1 (log)
181+ NUMPY_SVML_D1 (sin)
182+ NUMPY_SVML_D1 (cos)
183+ NUMPY_SVML_D1 (tan)
184+ NUMPY_SVML_D1 (asin)
185+ NUMPY_SVML_D1 (acos)
186+ NUMPY_SVML_D1 (atan)
187+ NUMPY_SVML_D1 (log10)
188+ NUMPY_SVML_D1 (log2)
189+ NUMPY_SVML_D1 (exp2)
190+ #undef NUMPY_SVML_D1
191+
192+ // 2-arg dispatchers (parameter names differ: pow(x,e) vs atan2(y,x))
248193template <typename T> inline T pow (T x, T e) { return svml_impl<T>::pow (x, e); }
249194template <typename T> inline T atan2 (T y, T x) { return svml_impl<T>::atan2 (y, x); }
250195
0 commit comments