diff --git a/cblas.h b/cblas.h index 8395f1b8b2..068fb34bde 100644 --- a/cblas.h +++ b/cblas.h @@ -40,6 +40,8 @@ extern "C" { /*Set the number of threads on runtime.*/ void openblas_set_num_threads(int num_threads); void goto_set_num_threads(int num_threads); +// "Local" means this thread count is used when OpenBLAS detects that +// it is already inside an OpenMP parallel region (`omp_in_parallel()` is nonzero). int openblas_set_num_threads_local(int num_threads); /*Get the number of threads on runtime.*/ diff --git a/common_thread.h b/common_thread.h index 4a8db682bf..3d34ebafe2 100644 --- a/common_thread.h +++ b/common_thread.h @@ -138,31 +138,30 @@ typedef struct blas_queue { extern int blas_server_avail; extern int blas_omp_number_max; extern int blas_omp_threads_local; +extern int blas_is_num_threads_set_explicitly; /* nonzero once openblas_set_num_threads() has been called */ static __inline int num_cpu_avail(int level) { #ifdef USE_OPENMP -int openmp_nthreads; - openmp_nthreads=omp_get_max_threads(); - if (omp_in_parallel()) openmp_nthreads = blas_omp_threads_local; -#endif + /* If the user explicitly called openblas_set_num_threads(), respect + that setting instead of overriding it with `omp_get_max_threads()` + below, which only supplies a default when no explicit choice was made. + NOTE(review): this return also skips the `omp_in_parallel()` selection below -- confirm intended. */ + if (blas_is_num_threads_set_explicitly) { + return blas_cpu_number; + } -#ifndef USE_OPENMP - if (blas_cpu_number == 1 -#else - if (openmp_nthreads == 1 -#endif - ) return 1; + int openmp_nthreads = omp_in_parallel() ? 
blas_omp_threads_local : omp_get_max_threads(); -#ifdef USE_OPENMP - if (openmp_nthreads > blas_omp_number_max){ + if (openmp_nthreads > blas_omp_number_max) { #ifdef DEBUG - fprintf(stderr,"WARNING - more OpenMP threads requested (%d) than available (%d)\n",openmp_nthreads,blas_omp_number_max); + fprintf(stderr, "WARNING - more OpenMP threads requested (%d) than available (%d)\n", openmp_nthreads, blas_omp_number_max); #endif - openmp_nthreads = blas_omp_number_max; - } - if (blas_cpu_number != openmp_nthreads) { - goto_set_num_threads(openmp_nthreads); + openmp_nthreads = blas_omp_number_max; + } + + if (blas_cpu_number != openmp_nthreads) { + goto_set_num_threads(openmp_nthreads); // mutates `blas_cpu_number` } #endif diff --git a/driver/others/blas_server_omp.c b/driver/others/blas_server_omp.c index 38b48fc842..1f6e35e184 100644 --- a/driver/others/blas_server_omp.c +++ b/driver/others/blas_server_omp.c @@ -69,7 +69,8 @@ int blas_server_avail = 0; int blas_omp_number_max = 0; -int blas_omp_threads_local = 1; +int blas_omp_threads_local = 1; // number of threads to use when already inside an OpenMP parallel region (omp_in_parallel() != 0) +int blas_is_num_threads_set_explicitly = 0; // set to 1 when the user calls openblas_set_num_threads() extern int openblas_omp_adaptive_env(void); @@ -122,7 +123,7 @@ void goto_set_num_threads(int num_threads) { } void openblas_set_num_threads(int num_threads) { - + blas_is_num_threads_set_explicitly = 1; // record the explicit request; num_cpu_avail() checks this flag goto_set_num_threads(num_threads); }