Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions ggml/src/ggml-cuda/mmq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ static constexpr __device__ int get_mmq_x_max_device() {
}

// Host-side selection of the mmq_y value (64 or 128) for the compute capability `cc`.
// NOTE(review): presumably this must agree with get_mmq_y_device() for the same
// target so host launch code and device code use the same value -- confirm.
static int get_mmq_y_host(const int cc) {
    // RDNA3.5 is checked first so it overrides the generic AMD rule below.
    if (GGML_CUDA_CC_IS_RDNA3_5(cc)) {
        return 64;
    }

    // Remaining AMD devices: 64 on RDNA1, 128 on everything else.
    if (GGML_CUDA_CC_IS_AMD(cc)) {
        if (GGML_CUDA_CC_IS_RDNA1(cc)) {
            return 64;
        }
        return 128;
    }

    // NVIDIA devices whose highest compiled arch is at least Volta get 128.
    if (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) {
        return 128;
    }

    // Any other device falls back to 64.
    return 64;
}
Expand All @@ -155,7 +158,9 @@ if (type == GGML_TYPE_NVFP4 || type == GGML_TYPE_MXFP4) {

static constexpr __device__ int get_mmq_y_device() {
#if defined(GGML_USE_HIP)
#if defined(RDNA1)
#if defined(RDNA3_5)
return 64;
#elif defined(RDNA1)
return 64;
#else
return 128;
Expand Down Expand Up @@ -296,6 +301,9 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)

#if defined(GGML_USE_HIP)
// Host-side selection of the warp count for the compute capability `cc`
// (HIP build variant).
//   cc        - compute capability identifier of the target device
//   warp_size - warp size of the target device; used as a divisor, so it must be non-zero
static int mmq_get_nwarps_host(const int cc, const int warp_size) {
    // RDNA3.5 is special-cased ahead of the generic rules.
    if (GGML_CUDA_CC_IS_RDNA3_5(cc)) {
        return 4;
    }

    // Devices for which amd_mfma_available() reports true use a fixed 8.
    if (amd_mfma_available(cc)) {
        return 8;
    }

    // Otherwise derive the count from the warp size: 256 / warp_size.
    return 256 / warp_size;
}
#else
Expand All @@ -305,7 +313,9 @@ static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
#endif // (GGML_USE_HIP)

static constexpr __device__ int mmq_get_nwarps_device() {
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA3_5)
return 4;
#elif defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
return 8;
#else
return 256/ggml_cuda_get_physical_warp_size();
Expand Down
Loading