Skip to content

Commit 9b1a2dc

Browse files
author
peng.li24
committed
refactor: consolidate SVML wrappers with macros (-55 lines in svml_bridge.h)
1 parent 68d0328 commit 9b1a2dc

1 file changed

Lines changed: 121 additions & 176 deletions

File tree

numpy/svml_bridge.h

Lines changed: 121 additions & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,6 @@
11
// SVML/npy bridge — resolve functions from numpy's _multiarray_umath.so.
22
//
3-
// numpy uses SVML for ALL transcendental functions when available:
4-
// f64: __svml_exp8, __svml_log8, __svml_sin8, __svml_cos8, __svml_tan8,
5-
// __svml_asin8, __svml_acos8, __svml_atan8, __svml_log108, __svml_log28,
6-
// __svml_exp28, npy_pow, npy_atan2
7-
// f32: __svml_expf16, __svml_logf16, __svml_sinf16, __svml_cosf16, __svml_tanf16,
8-
// __svml_asinf16, __svml_acosf16, __svml_atanf16, __svml_log10f16, __svml_log2f16,
9-
// __svml_exp2f16, npy_pow, npy_atan2
10-
//
3+
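// numpy uses SVML for most transcendental functions when available; f64 pow/atan2 go through the scalar npy_pow/npy_atan2, and f32 exp/log/sin/cos use numpy's own polynomial approximations.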
114
// Call bridge_init(path_to_multiarray_umath_so) before first use.
125

136
#pragma once
@@ -42,73 +35,48 @@ inline void* resolve_svml(const char* name) {
4235
fprintf(stderr, "[SVML] resolve(%s) -> %p dlerror=%s\n", name, ptr, dlerror());
4336
return ptr;
4437
}
45-
fprintf(stderr, "[SVML] resolve(%s) -> NO HANDLE (g_svml_handle is null)\n", name);
38+
fprintf(stderr, "[SVML] resolve(%s) -> NO HANDLE\n", name);
4639
return nullptr;
4740
}
4841

49-
#ifdef __AVX512F__
50-
5142
// ============================================================================
52-
// float64 — SVML for most, npy_pow / npy_atan2 for pow/atan2
43+
// SVML wrapper macros — consolidate repetitive f64/f32 patterns
5344
// ============================================================================
5445

55-
inline double exp_f64(double x) {
56-
static auto fn = (__m512d (*)(__m512d))resolve_svml("__svml_exp8");
57-
if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(x)));
58-
return std::exp(x);
59-
}
60-
inline double log_f64(double x) {
61-
static auto fn = (__m512d (*)(__m512d))resolve_svml("__svml_log8");
62-
if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(x)));
63-
return std::log(x);
64-
}
65-
inline double sin_f64(double x) {
66-
static auto fn = (__m512d (*)(__m512d))resolve_svml("__svml_sin8");
67-
if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(x)));
68-
return std::sin(x);
69-
}
70-
inline double cos_f64(double x) {
71-
static auto fn = (__m512d (*)(__m512d))resolve_svml("__svml_cos8");
72-
if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(x)));
73-
return std::cos(x);
74-
}
75-
inline double tan_f64(double x) {
76-
static auto fn = (__m512d (*)(__m512d))resolve_svml("__svml_tan8");
77-
if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(x)));
78-
return std::tan(x);
79-
}
80-
inline double asin_f64(double x) {
81-
static auto fn = (__m512d (*)(__m512d))resolve_svml("__svml_asin8");
82-
if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(x)));
83-
return std::asin(x);
84-
}
85-
inline double acos_f64(double x) {
86-
static auto fn = (__m512d (*)(__m512d))resolve_svml("__svml_acos8");
87-
if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(x)));
88-
return std::acos(x);
89-
}
90-
inline double atan_f64(double x) {
91-
static auto fn = (__m512d (*)(__m512d))resolve_svml("__svml_atan8");
92-
if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(x)));
93-
return std::atan(x);
94-
}
95-
inline double log10_f64(double x) {
96-
static auto fn = (__m512d (*)(__m512d))resolve_svml("__svml_log108");
97-
if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(x)));
98-
return std::log10(x);
99-
}
100-
inline double log2_f64(double x) {
101-
static auto fn = (__m512d (*)(__m512d))resolve_svml("__svml_log28");
102-
if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(x)));
103-
return std::log2(x);
104-
}
105-
inline double exp2_f64(double x) {
106-
static auto fn = (__m512d (*)(__m512d))resolve_svml("__svml_exp28");
107-
if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(x)));
108-
return std::exp2(x);
109-
}
46+
#ifdef __AVX512F__
11047

111-
// numpy uses npy_pow / npy_atan2 (scalar C functions, not SVML vector)
48+
// 1-arg f64: __svml_XXX8 → __m512d → extract scalar
49+
// 1-arg f32: __svml_XXXf16 → __m512 → extract scalar
50+
#define NUMPY_SVML_F64(name, sym, fallback) \
51+
inline double name(double x) { \
52+
static auto fn = (__m512d (*)(__m512d)) \
53+
resolve_svml(sym); \
54+
if (fn) return _mm512_cvtsd_f64( \
55+
fn(_mm512_set1_pd(x))); \
56+
return fallback(x); \
57+
}
58+
#define NUMPY_SVML_F32(name, sym, fallback) \
59+
inline float name(float x) { \
60+
static auto fn = (__m512 (*)(__m512)) \
61+
resolve_svml(sym); \
62+
if (fn) return _mm512_cvtss_f32( \
63+
fn(_mm512_set1_ps(x))); \
64+
return fallback(x); \
65+
}
66+
67+
NUMPY_SVML_F64(exp_f64, "__svml_exp8", std::exp)
68+
NUMPY_SVML_F64(log_f64, "__svml_log8", std::log)
69+
NUMPY_SVML_F64(sin_f64, "__svml_sin8", std::sin)
70+
NUMPY_SVML_F64(cos_f64, "__svml_cos8", std::cos)
71+
NUMPY_SVML_F64(tan_f64, "__svml_tan8", std::tan)
72+
NUMPY_SVML_F64(asin_f64, "__svml_asin8", std::asin)
73+
NUMPY_SVML_F64(acos_f64, "__svml_acos8", std::acos)
74+
NUMPY_SVML_F64(atan_f64, "__svml_atan8", std::atan)
75+
NUMPY_SVML_F64(log10_f64,"__svml_log108",std::log10)
76+
NUMPY_SVML_F64(log2_f64, "__svml_log28", std::log2)
77+
NUMPY_SVML_F64(exp2_f64, "__svml_exp28", std::exp2)
78+
79+
// pow/atan2 — 2-arg scalar (npy_pow / npy_atan2)
11280
inline double pow_f64(double x, double e) {
11381
static auto fn = (double (*)(double, double))resolve_svml("npy_pow");
11482
if (fn) return fn(x, e);
@@ -120,131 +88,108 @@ inline double atan2_f64(double y, double x) {
12088
return std::atan2(y, x);
12189
}
12290

123-
// ============================================================================
124-
// float32 — SVML for ALL transcendental functions
125-
// ============================================================================
126-
127-
// exp/log/sin/cos — use numpy's own polynomial approximations for bit-exact results
91+
// f32: exp/log/sin/cos use numpy's own polynomial approximations
12892
inline float exp_f32(float x) { return npy_float_math::npy_expf(x); }
12993
inline float log_f32(float x) { return npy_float_math::npy_logf(x); }
13094
inline float sin_f32(float x) { return npy_float_math::npy_sinf(x); }
13195
inline float cos_f32(float x) { return npy_float_math::npy_cosf(x); }
132-
inline float tan_f32(float x) {
133-
static auto fn = (__m512 (*)(__m512))resolve_svml("__svml_tanf16");
134-
if (fn) return _mm512_cvtss_f32(fn(_mm512_set1_ps(x)));
135-
return std::tan(x);
136-
}
137-
inline float asin_f32(float x) {
138-
static auto fn = (__m512 (*)(__m512))resolve_svml("__svml_asinf16");
139-
if (fn) return _mm512_cvtss_f32(fn(_mm512_set1_ps(x)));
140-
return std::asin(x);
141-
}
142-
inline float acos_f32(float x) {
143-
static auto fn = (__m512 (*)(__m512))resolve_svml("__svml_acosf16");
144-
if (fn) return _mm512_cvtss_f32(fn(_mm512_set1_ps(x)));
145-
return std::acos(x);
146-
}
147-
inline float atan_f32(float x) {
148-
static auto fn = (__m512 (*)(__m512))resolve_svml("__svml_atanf16");
149-
if (fn) return _mm512_cvtss_f32(fn(_mm512_set1_ps(x)));
150-
return std::atan(x);
151-
}
152-
inline float log10_f32(float x) {
153-
static auto fn = (__m512 (*)(__m512))resolve_svml("__svml_log10f16");
154-
if (fn) return _mm512_cvtss_f32(fn(_mm512_set1_ps(x)));
155-
return std::log10(x);
156-
}
157-
inline float log2_f32(float x) {
158-
static auto fn = (__m512 (*)(__m512))resolve_svml("__svml_log2f16");
159-
if (fn) return _mm512_cvtss_f32(fn(_mm512_set1_ps(x)));
160-
return std::log2(x);
161-
}
162-
inline float exp2_f32(float x) {
163-
static auto fn = (__m512 (*)(__m512))resolve_svml("__svml_exp2f16");
164-
if (fn) return _mm512_cvtss_f32(fn(_mm512_set1_ps(x)));
165-
return std::exp2(x);
166-
}
167-
inline float pow_f32(float x, float e) { return std::pow(x, e); }
96+
97+
NUMPY_SVML_F32(tan_f32, "__svml_tanf16", std::tan)
98+
NUMPY_SVML_F32(asin_f32, "__svml_asinf16", std::asin)
99+
NUMPY_SVML_F32(acos_f32, "__svml_acosf16", std::acos)
100+
NUMPY_SVML_F32(atan_f32, "__svml_atanf16", std::atan)
101+
NUMPY_SVML_F32(log10_f32,"__svml_log10f16",std::log10)
102+
NUMPY_SVML_F32(log2_f32, "__svml_log2f16", std::log2)
103+
NUMPY_SVML_F32(exp2_f32, "__svml_exp2f16", std::exp2)
104+
105+
inline float pow_f32(float x, float e) { return std::pow(x, e); }
168106
inline float atan2_f32(float y, float x) { return std::atan2(y, x); }
169107

108+
#undef NUMPY_SVML_F64
109+
#undef NUMPY_SVML_F32
110+
170111
#else // !__AVX512F__
171112

172-
inline double exp_f64(double x) { return std::exp(x); }
173-
inline double log_f64(double x) { return std::log(x); }
174-
inline double sin_f64(double x) { return std::sin(x); }
175-
inline double cos_f64(double x) { return std::cos(x); }
176-
inline double tan_f64(double x) { return std::tan(x); }
177-
inline double asin_f64(double x) { return std::asin(x); }
178-
inline double acos_f64(double x) { return std::acos(x); }
179-
inline double atan_f64(double x) { return std::atan(x); }
180-
inline double log10_f64(double x){ return std::log10(x); }
181-
inline double log2_f64(double x) { return std::log2(x); }
182-
inline double exp2_f64(double x) { return std::exp2(x); }
183-
inline double pow_f64(double x, double e) { return std::pow(x, e); }
113+
// Non-AVX512: every wrapper degrades to its std:: or npy_float_math:: fallback
114+
#define NUMPY_FB_F64(name, fallback) inline double name(double x) { return fallback(x); }
115+
#define NUMPY_FB_F32(name, fallback) inline float name(float x) { return fallback(x); }
116+
117+
NUMPY_FB_F64(exp_f64, std::exp)
118+
NUMPY_FB_F64(log_f64, std::log)
119+
NUMPY_FB_F64(sin_f64, std::sin)
120+
NUMPY_FB_F64(cos_f64, std::cos)
121+
NUMPY_FB_F64(tan_f64, std::tan)
122+
NUMPY_FB_F64(asin_f64, std::asin)
123+
NUMPY_FB_F64(acos_f64, std::acos)
124+
NUMPY_FB_F64(atan_f64, std::atan)
125+
NUMPY_FB_F64(log10_f64,std::log10)
126+
NUMPY_FB_F64(log2_f64, std::log2)
127+
NUMPY_FB_F64(exp2_f64, std::exp2)
128+
inline double pow_f64(double x, double e) { return std::pow(x, e); }
184129
inline double atan2_f64(double y, double x) { return std::atan2(y, x); }
185-
inline float exp_f32(float x) { return npy_float_math::npy_expf(x); }
186-
inline float log_f32(float x) { return npy_float_math::npy_logf(x); }
187-
inline float sin_f32(float x) { return npy_float_math::npy_sinf(x); }
188-
inline float cos_f32(float x) { return npy_float_math::npy_cosf(x); }
189-
inline float tan_f32(float x) { return std::tan(x); }
190-
inline float asin_f32(float x) { return std::asin(x); }
191-
inline float acos_f32(float x) { return std::acos(x); }
192-
inline float atan_f32(float x) { return std::atan(x); }
193-
inline float log10_f32(float x) { return std::log10(x); }
194-
inline float log2_f32(float x) { return std::log2(x); }
195-
inline float exp2_f32(float x) { return std::exp2(x); }
196-
inline float pow_f32(float x, float e) { return std::pow(x, e); }
197-
inline float atan2_f32(float y, float x) { return std::atan2(y, x); }
130+
131+
NUMPY_FB_F32(exp_f32, npy_float_math::npy_expf)
132+
NUMPY_FB_F32(log_f32, npy_float_math::npy_logf)
133+
NUMPY_FB_F32(sin_f32, npy_float_math::npy_sinf)
134+
NUMPY_FB_F32(cos_f32, npy_float_math::npy_cosf)
135+
NUMPY_FB_F32(tan_f32, std::tan)
136+
NUMPY_FB_F32(asin_f32, std::asin)
137+
NUMPY_FB_F32(acos_f32, std::acos)
138+
NUMPY_FB_F32(atan_f32, std::atan)
139+
NUMPY_FB_F32(log10_f32,std::log10)
140+
NUMPY_FB_F32(log2_f32, std::log2)
141+
NUMPY_FB_F32(exp2_f32, std::exp2)
142+
inline float pow_f32(float x, float e) { return std::pow(x, e); }
143+
inline float atan2_f32(float y, float x) { return std::atan2(y, x); }
144+
145+
#undef NUMPY_FB_F64
146+
#undef NUMPY_FB_F32
198147

199148
#endif // __AVX512F__
200149

201150
// ============================================================================
202-
// Template dispatchers
151+
// Template dispatchers — svml_impl<T> + free function templates
203152
// ============================================================================
204153

205-
template<typename T> struct svml_impl;
206-
template<> struct svml_impl<double> {
207-
static double exp(double x) { return exp_f64(x); }
208-
static double log(double x) { return log_f64(x); }
209-
static double sin(double x) { return sin_f64(x); }
210-
static double cos(double x) { return cos_f64(x); }
211-
static double tan(double x) { return tan_f64(x); }
212-
static double asin(double x) { return asin_f64(x); }
213-
static double acos(double x) { return acos_f64(x); }
214-
static double atan(double x) { return atan_f64(x); }
215-
static double log10(double x){ return log10_f64(x); }
216-
static double log2(double x) { return log2_f64(x); }
217-
static double exp2(double x) { return exp2_f64(x); }
218-
static double pow(double x, double e) { return pow_f64(x, e); }
219-
static double atan2(double y, double x) { return atan2_f64(y, x); }
220-
};
221-
template<> struct svml_impl<float> {
222-
static float exp(float x) { return exp_f32(x); }
223-
static float log(float x) { return log_f32(x); }
224-
static float sin(float x) { return sin_f32(x); }
225-
static float cos(float x) { return cos_f32(x); }
226-
static float tan(float x) { return tan_f32(x); }
227-
static float asin(float x) { return asin_f32(x); }
228-
static float acos(float x) { return acos_f32(x); }
229-
static float atan(float x) { return atan_f32(x); }
230-
static float log10(float x){ return log10_f32(x); }
231-
static float log2(float x) { return log2_f32(x); }
232-
static float exp2(float x) { return exp2_f32(x); }
233-
static float pow(float x, float e) { return pow_f32(x, e); }
234-
static float atan2(float y, float x) { return atan2_f32(y, x); }
154+
#define NUMPY_SVML_METHODS(T, suff) \
155+
template<> struct svml_impl<T> { \
156+
static T exp(T x) { return exp_##suff(x); } \
157+
static T log(T x) { return log_##suff(x); } \
158+
static T sin(T x) { return sin_##suff(x); } \
159+
static T cos(T x) { return cos_##suff(x); } \
160+
static T tan(T x) { return tan_##suff(x); } \
161+
static T asin(T x) { return asin_##suff(x); } \
162+
static T acos(T x) { return acos_##suff(x); } \
163+
static T atan(T x) { return atan_##suff(x); } \
164+
static T log10(T x){ return log10_##suff(x); } \
165+
static T log2(T x) { return log2_##suff(x); } \
166+
static T exp2(T x) { return exp2_##suff(x); } \
167+
static T pow(T x, T e) { return pow_##suff(x, e); } \
168+
static T atan2(T y, T x) { return atan2_##suff(y, x); } \
235169
};
236170

237-
template<typename T> inline T exp(T x) { return svml_impl<T>::exp(x); }
238-
template<typename T> inline T log(T x) { return svml_impl<T>::log(x); }
239-
template<typename T> inline T sin(T x) { return svml_impl<T>::sin(x); }
240-
template<typename T> inline T cos(T x) { return svml_impl<T>::cos(x); }
241-
template<typename T> inline T tan(T x) { return svml_impl<T>::tan(x); }
242-
template<typename T> inline T asin(T x) { return svml_impl<T>::asin(x); }
243-
template<typename T> inline T acos(T x) { return svml_impl<T>::acos(x); }
244-
template<typename T> inline T atan(T x) { return svml_impl<T>::atan(x); }
245-
template<typename T> inline T log10(T x) { return svml_impl<T>::log10(x); }
246-
template<typename T> inline T log2(T x) { return svml_impl<T>::log2(x); }
247-
template<typename T> inline T exp2(T x) { return svml_impl<T>::exp2(x); }
171+
template<typename T> struct svml_impl;
172+
NUMPY_SVML_METHODS(double, f64)
173+
NUMPY_SVML_METHODS(float, f32)
174+
#undef NUMPY_SVML_METHODS
175+
176+
// 1-arg dispatchers
177+
#define NUMPY_SVML_D1(name) \
178+
template<typename T> inline T name(T x) { return svml_impl<T>::name(x); }
179+
NUMPY_SVML_D1(exp)
180+
NUMPY_SVML_D1(log)
181+
NUMPY_SVML_D1(sin)
182+
NUMPY_SVML_D1(cos)
183+
NUMPY_SVML_D1(tan)
184+
NUMPY_SVML_D1(asin)
185+
NUMPY_SVML_D1(acos)
186+
NUMPY_SVML_D1(atan)
187+
NUMPY_SVML_D1(log10)
188+
NUMPY_SVML_D1(log2)
189+
NUMPY_SVML_D1(exp2)
190+
#undef NUMPY_SVML_D1
191+
192+
// 2-arg dispatchers (parameter names differ: pow(x,e) vs atan2(y,x))
248193
template<typename T> inline T pow(T x, T e) { return svml_impl<T>::pow(x, e); }
249194
template<typename T> inline T atan2(T y, T x) { return svml_impl<T>::atan2(y, x); }
250195

0 commit comments

Comments
 (0)