@@ -88,6 +88,21 @@ using sdot64_fn = float (const int64_t*, const float*, const int64_t*,
8888using ddot64_fn = double (const int64_t *, const double *, const int64_t *,
8989 const double *, const int64_t *);
9090
// Function types for the ILP64 C BLAS GEMM entry points
// (cblas_sgemm64_ / cblas_dgemm64_, i.e. BLAS_SYMBOL_SUFFIX=64_).
// The layout / transpose selectors are passed as plain ints:
//   layout   : 101 = CblasRowMajor
//   transA/B : 111 = CblasNoTrans
// Every dimension and leading stride is a 64-bit integer.
using cblas_sgemm64_fn = void(int layout, int transA, int transB,
                              int64_t M, int64_t N, int64_t K,
                              float alpha,
                              const float* A, int64_t lda,
                              const float* B, int64_t ldb,
                              float beta,
                              float* C, int64_t ldc);
using cblas_dgemm64_fn = void(int layout, int transA, int transB,
                              int64_t M, int64_t N, int64_t K,
                              double alpha,
                              const double* A, int64_t lda,
                              const double* B, int64_t ldb,
                              double beta,
                              double* C, int64_t ldc);
105+
91106inline float blas_sdot (const float * x, const float * y, size_t n) {
92107 static auto fn = (sdot64_fn*)resolve_blas (" sdot_64_" );
93108 if (__builtin_expect (fn != nullptr , 1 )) {
@@ -111,16 +126,136 @@ inline double blas_ddot(const double* x, const double* y, size_t n) {
111126 return r;
112127}
113128
// Function types for the ILP64 C BLAS matrix-vector entry points
// (cblas_sgemv64_ / cblas_dgemv64_).
// Argument order: layout, trans, M, N, alpha, A, lda, x, incx, beta, y, incy.
// layout and trans are passed as plain ints; dimensions, the leading
// stride and both vector increments are 64-bit integers.
using cblas_sgemv64_fn = void(int layout, int trans,
                              int64_t M, int64_t N,
                              float alpha,
                              const float* A, int64_t lda,
                              const float* x, int64_t incx,
                              float beta,
                              float* y, int64_t incy);
using cblas_dgemv64_fn = void(int layout, int trans,
                              int64_t M, int64_t N,
                              double alpha,
                              const double* A, int64_t lda,
                              const double* x, int64_t incx,
                              double beta,
                              double* y, int64_t incy);
139+
140+ // y[M] = A[M×K] @ x[K] — 2D × 1D case
141+ inline void blas_sgemv (const float * A, const float * x, float * y, size_t M, size_t K) {
142+ static auto fn = (cblas_sgemv64_fn*)resolve_blas (" cblas_sgemv64_" );
143+ if (__builtin_expect (fn != nullptr , 1 )) {
144+ fn (101 , 111 , (int64_t )M, (int64_t )K, 1 .0f , A, (int64_t )K,
145+ x, 1 , 0 .0f , y, 1 );
146+ return ;
147+ }
148+ for (size_t i = 0 ; i < M; ++i) {
149+ float s = 0 .0f ;
150+ for (size_t k = 0 ; k < K; ++k) s += A[i*K+k] * x[k];
151+ y[i] = s;
152+ }
153+ }
154+ inline void blas_dgemv (const double * A, const double * x, double * y, size_t M, size_t K) {
155+ static auto fn = (cblas_dgemv64_fn*)resolve_blas (" cblas_dgemv64_" );
156+ if (__builtin_expect (fn != nullptr , 1 )) {
157+ fn (101 , 111 , (int64_t )M, (int64_t )K, 1.0 , A, (int64_t )K,
158+ x, 1 , 0.0 , y, 1 );
159+ return ;
160+ }
161+ for (size_t i = 0 ; i < M; ++i) {
162+ double s = 0.0 ;
163+ for (size_t k = 0 ; k < K; ++k) s += A[i*K+k] * x[k];
164+ y[i] = s;
165+ }
166+ }
167+
168+ // y[N] = B^T[K×N] @ a[K] — 1D × 2D case (Trans=112)
169+ inline void blas_sgemv_t (const float * B, const float * a, float * y, size_t K, size_t N) {
170+ static auto fn = (cblas_sgemv64_fn*)resolve_blas (" cblas_sgemv64_" );
171+ if (__builtin_expect (fn != nullptr , 1 )) {
172+ fn (101 , 112 , (int64_t )K, (int64_t )N, 1 .0f , B, (int64_t )N,
173+ a, 1 , 0 .0f , y, 1 );
174+ return ;
175+ }
176+ for (size_t j = 0 ; j < N; ++j) {
177+ float s = 0 .0f ;
178+ for (size_t k = 0 ; k < K; ++k) s += B[k*N+j] * a[k];
179+ y[j] = s;
180+ }
181+ }
182+ inline void blas_dgemv_t (const double * B, const double * a, double * y, size_t K, size_t N) {
183+ static auto fn = (cblas_dgemv64_fn*)resolve_blas (" cblas_dgemv64_" );
184+ if (__builtin_expect (fn != nullptr , 1 )) {
185+ fn (101 , 112 , (int64_t )K, (int64_t )N, 1.0 , B, (int64_t )N,
186+ a, 1 , 0.0 , y, 1 );
187+ return ;
188+ }
189+ for (size_t j = 0 ; j < N; ++j) {
190+ double s = 0.0 ;
191+ for (size_t k = 0 ; k < K; ++k) s += B[k*N+j] * a[k];
192+ y[j] = s;
193+ }
194+ }
195+
196+ // C = A @ B (all row-major) C[M×N] = A[M×K] @ B[K×N]
197+ // Uses cblas_sgemm64_ — same kernel numpy.matmul calls → 0 ULP by construction.
198+ inline void blas_sgemm (const float * A, const float * B, float * C,
199+ size_t M, size_t K, size_t N) {
200+ static auto fn = (cblas_sgemm64_fn*)resolve_blas (" cblas_sgemm64_" );
201+ if (__builtin_expect (fn != nullptr , 1 )) {
202+ fn (101 , 111 , 111 , // RowMajor, NoTrans, NoTrans
203+ (int64_t )M, (int64_t )N, (int64_t )K,
204+ 1 .0f , A, (int64_t )K, B, (int64_t )N,
205+ 0 .0f , C, (int64_t )N);
206+ return ;
207+ }
208+ // Fallback (no OpenBLAS): naive triple loop — not bit-exact but always correct
209+ for (size_t i = 0 ; i < M; ++i)
210+ for (size_t j = 0 ; j < N; ++j) {
211+ float s = 0 .0f ;
212+ for (size_t k = 0 ; k < K; ++k) s += A[i*K+k] * B[k*N+j];
213+ C[i*N+j] = s;
214+ }
215+ }
216+
217+ inline void blas_dgemm (const double * A, const double * B, double * C,
218+ size_t M, size_t K, size_t N) {
219+ static auto fn = (cblas_dgemm64_fn*)resolve_blas (" cblas_dgemm64_" );
220+ if (__builtin_expect (fn != nullptr , 1 )) {
221+ fn (101 , 111 , 111 ,
222+ (int64_t )M, (int64_t )N, (int64_t )K,
223+ 1.0 , A, (int64_t )K, B, (int64_t )N,
224+ 0.0 , C, (int64_t )N);
225+ return ;
226+ }
227+ for (size_t i = 0 ; i < M; ++i)
228+ for (size_t j = 0 ; j < N; ++j) {
229+ double s = 0.0 ;
230+ for (size_t k = 0 ; k < K; ++k) s += A[i*K+k] * B[k*N+j];
231+ C[i*N+j] = s;
232+ }
233+ }
234+
// Template dispatcher: the primary template is only declared; the
// operations live in the explicit specialisations (float and double).
template <class T>
struct blas_ops;
116237
117238template <> struct blas_ops <float > {
118- static float dot (const float * x, const float * y, size_t n) { return blas_sdot (x, y, n); }
119- static float norm (const float * x, size_t n) { return std::sqrt (blas_sdot (x, x, n)); }
239+ static float dot (const float * x, const float * y, size_t n) { return blas_sdot (x, y, n); }
240+ static float norm (const float * x, size_t n) { return std::sqrt (blas_sdot (x, x, n)); }
241+ static void gemm (const float * A, const float * B, float * C,
242+ size_t M, size_t K, size_t N) { blas_sgemm (A, B, C, M, K, N); }
243+ // y[M] = A[M×K] @ x[K]
244+ static void gemv (const float * A, const float * x, float * y,
245+ size_t M, size_t K) { blas_sgemv (A, x, y, M, K); }
246+ // y[N] = B^T @ a[K] (1D × 2D case)
247+ static void gemvt (const float * B, const float * a, float * y,
248+ size_t K, size_t N) { blas_sgemv_t (B, a, y, K, N); }
120249};
121250template <> struct blas_ops <double > {
122- static double dot (const double * x, const double * y, size_t n) { return blas_ddot (x, y, n); }
123- static double norm (const double * x, size_t n) { return std::sqrt (blas_ddot (x, x, n)); }
251+ static double dot (const double * x, const double * y, size_t n) { return blas_ddot (x, y, n); }
252+ static double norm (const double * x, size_t n) { return std::sqrt (blas_ddot (x, x, n)); }
253+ static void gemm (const double * A, const double * B, double * C,
254+ size_t M, size_t K, size_t N) { blas_dgemm (A, B, C, M, K, N); }
255+ static void gemv (const double * A, const double * x, double * y,
256+ size_t M, size_t K) { blas_dgemv (A, x, y, M, K); }
257+ static void gemvt (const double * B, const double * a, double * y,
258+ size_t K, size_t N) { blas_dgemv_t (B, a, y, K, N); }
124259};
125260
126261} // namespace detail
0 commit comments