diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 1749b5734a..1685120ce3 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -38,6 +38,8 @@ namespace transformer_engine { std::string to_string(const DType type); std::string to_string(const NVTEScalingMode &mode); +inline std::string to_string_like(const DType &val) { return to_string(val); } + inline bool is_tensor_scaling(const NVTEScalingMode &mode) { return mode == NVTE_DELAYED_TENSOR_SCALING; } @@ -619,140 +621,149 @@ struct TypeInfo { #define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing #endif -#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kByte: { \ - using type = unsigned char; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kInt16: { \ - using type = int16_t; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kInt32: { \ - using type = int32_t; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kInt64: { \ - using type = int64_t; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E4M3: { \ - using type = fp8e4m3; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E5M2: { \ - using type = fp8e5m2; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E8M0: { \ - using type = byte; \ - { __VA_ARGS__ } \ - } break; \ - SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ - default: \ - NVTE_ERROR("Invalid type."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kByte: { \ + using type = unsigned char; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt16: { \ + using type = int16_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt32: { \ + using type = int32_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt64: { \ + using type = int64_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E8M0: { \ + using type = byte; \ + { __VA_ARGS__ } \ + } break; \ + SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ + default: \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Byte, Int16, Int32, Int64, Float32, " \ + "Float16, BFloat16, Float8E4M3, Float8E5M2, " \ + "Float8E8M0, Float4E2M1."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E4M3: { \ - using type = fp8e4m3; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E5M2: { \ - using type = fp8e5m2; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Invalid type."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16, " \ + "Float8E4M3, Float8E5M2."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E5M2: { \ - using type = fp8e5m2; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E4M3: { \ - using type = fp8e4m3; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Invalid type."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported output dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16, " \ + "Float8E5M2, Float8E4M3."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Invalid type."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Invalid type, expected Float32 or BFloat16."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, BFloat16."); \ } // Add a pack_size argument to select the packed type for FP4 @@ -764,80 +775,90 @@ struct TypeInfo { { __VA_ARGS__ } \ } break; \ default: \ - NVTE_ERROR("Invalid type."); \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". Expected: Float4E2M1."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat8E5M2: { \ - using type = fp8e5m2; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E4M3: { \ - using type = fp8e4m3; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Invalid type."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float8E5M2, Float8E4M3."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E5M2: \ - case DType::kFloat8E4M3: { \ - NVTE_ERROR("FP8 type not instantiated for input."); \ - } break; \ - case DType::kFloat4E2M1: { \ - NVTE_ERROR("FP4 type not instantiated for input."); \ - } break; \ - default: \ - NVTE_ERROR("Invalid type."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: \ + case DType::kFloat8E4M3: { \ + NVTE_ERROR("FP8 dtype ", to_string(static_cast(dtype)), \ + " is not instantiated for input. " \ + "Expected one of: Float32, Float16, BFloat16."); \ + } break; \ + case DType::kFloat4E2M1: { \ + NVTE_ERROR( \ + "FP4 dtype Float4E2M1 is not instantiated " \ + "for input. Expected one of: Float32, Float16, " \ + "BFloat16."); \ + } break; \ + default: \ + NVTE_ERROR("Unsupported input dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat16: { \ - using type = fp16; \ - __VA_ARGS__; \ - break; \ - } \ - case DType::kBFloat16: { \ - using type = bf16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - NVTE_ERROR("Invalid type for 16 bit."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat16: { \ + using type = fp16; \ + __VA_ARGS__; \ + break; \ + } \ + case DType::kBFloat16: { \ + using type = bf16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + NVTE_ERROR("Unsupported 16-bit dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float16, BFloat16."); \ } -#define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \ - switch (SCALE_DIM) { \ - case 1: { \ - constexpr size_t DIM = 1; \ - { __VA_ARGS__ } \ - } break; \ - case 32: { \ - constexpr size_t DIM = 32; \ - { __VA_ARGS__ } \ - } break; \ - default: { \ - NVTE_ERROR("Invalid size of the MX scaling factor."); \ - } \ +#define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \ + switch (SCALE_DIM) { \ + case 1: { \ + constexpr size_t DIM = 1; \ + { __VA_ARGS__ } \ + } break; \ + case 32: { \ + constexpr size_t DIM = 32; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Unsupported MX scaling factor dimension ", SCALE_DIM, \ + ". Expected one of: 1, 32."); \ + } \ } #define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) \ diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index abdce7fdac..6a136c67e4 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -15,6 +15,88 @@ #include "fused_attn_fp8.h" #include "utils.h" +namespace transformer_engine { + +std::string to_string(NVTE_QKV_Layout layout) { + switch (layout) { + case NVTE_SB3HD: + return "NVTE_SB3HD"; + case NVTE_SBH3D: + return "NVTE_SBH3D"; + case NVTE_SBHD_SB2HD: + return "NVTE_SBHD_SB2HD"; + case NVTE_SBHD_SBH2D: + return "NVTE_SBHD_SBH2D"; + case NVTE_SBHD_SBHD_SBHD: + return "NVTE_SBHD_SBHD_SBHD"; + case NVTE_BS3HD: + return "NVTE_BS3HD"; + case NVTE_BSH3D: + return "NVTE_BSH3D"; + case NVTE_BSHD_BS2HD: + return "NVTE_BSHD_BS2HD"; + case NVTE_BSHD_BSH2D: + return "NVTE_BSHD_BSH2D"; + case NVTE_BSHD_BSHD_BSHD: + return "NVTE_BSHD_BSHD_BSHD"; + case NVTE_T3HD: + return "NVTE_T3HD"; + case NVTE_TH3D: + return "NVTE_TH3D"; + case NVTE_THD_T2HD: + return "NVTE_THD_T2HD"; + case NVTE_THD_TH2D: + return "NVTE_THD_TH2D"; + case NVTE_THD_THD_THD: + return "NVTE_THD_THD_THD"; + case NVTE_SBHD_BSHD_BSHD: + return "NVTE_SBHD_BSHD_BSHD"; + case NVTE_BSHD_SBHD_SBHD: + return "NVTE_BSHD_SBHD_SBHD"; + case NVTE_THD_BSHD_BSHD: + return "NVTE_THD_BSHD_BSHD"; + case NVTE_THD_SBHD_SBHD: + return "NVTE_THD_SBHD_SBHD"; + case NVTE_Paged_KV_BSHD_BSHD_BSHD: + return "NVTE_Paged_KV_BSHD_BSHD_BSHD"; + case NVTE_Paged_KV_BSHD_SBHD_SBHD: + return "NVTE_Paged_KV_BSHD_SBHD_SBHD"; + case NVTE_Paged_KV_SBHD_BSHD_BSHD: + return "NVTE_Paged_KV_SBHD_BSHD_BSHD"; + case NVTE_Paged_KV_SBHD_SBHD_SBHD: + return "NVTE_Paged_KV_SBHD_SBHD_SBHD"; + case NVTE_Paged_KV_THD_BSHD_BSHD: + return "NVTE_Paged_KV_THD_BSHD_BSHD"; + case NVTE_Paged_KV_THD_SBHD_SBHD: + return "NVTE_Paged_KV_THD_SBHD_SBHD"; + default: + return "UNKNOWN_QKV_LAYOUT(" + std::to_string(static_cast(layout)) + ")"; + } +} + +std::string to_string(NVTE_QKV_Format format) { + switch (format) { + case NVTE_SBHD: + return "NVTE_SBHD"; + case NVTE_BSHD: + return "NVTE_BSHD"; + case NVTE_THD: + return "NVTE_THD"; + case NVTE_BSHD_2SBHD: + return "NVTE_BSHD_2SBHD"; + case NVTE_SBHD_2BSHD: + return "NVTE_SBHD_2BSHD"; + case NVTE_THD_2BSHD: + return "NVTE_THD_2BSHD"; + case NVTE_THD_2SBHD: + return "NVTE_THD_2SBHD"; + default: + return "UNKNOWN_QKV_FORMAT(" + std::to_string(static_cast(format)) + ")"; + } +} + +} // namespace transformer_engine + // map NVTE_QKV_Layout to NVTE_QKV_Layout_Group NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { switch (qkv_layout) { @@ -50,7 +132,8 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD; default: - NVTE_ERROR("qkv_layout not supported!"); + NVTE_ERROR("Unsupported qkv_layout ", transformer_engine::to_string(qkv_layout), + " in nvte_get_qkv_layout_group."); } } @@ -90,7 +173,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Format::NVTE_THD_2SBHD; default: - NVTE_ERROR("qkv_layout not supported!"); + NVTE_ERROR("Unsupported qkv_layout ", transformer_engine::to_string(qkv_layout), + " in nvte_get_qkv_format."); } } @@ -109,7 +193,8 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Format::NVTE_THD_2SBHD: return NVTE_QKV_Format::NVTE_THD; default: - NVTE_ERROR("qkv_layout not supported!"); + NVTE_ERROR("Unsupported qkv_format ", transformer_engine::to_string(qkv_format), + " in nvte_get_q_format."); } } @@ -128,7 +213,8 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Format::NVTE_THD: return NVTE_QKV_Format::NVTE_THD; default: - NVTE_ERROR("qkv_layout not supported!"); + NVTE_ERROR("Unsupported qkv_format ", transformer_engine::to_string(qkv_format), + " in nvte_get_kv_format."); } } diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 60e731d990..372efdc490 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -250,46 +250,49 @@ __device__ inline void naive_topk_and_mask(CompType *scores, int data_size, int } // Current TE only support float32/bf16/fp16, float64 probs should be considered in the future -#define TE_ROUTER_PROBS_TYPE_SWITCH_ALL(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Invalid type."); \ +#define TE_ROUTER_PROBS_TYPE_SWITCH_ALL(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported router probs dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16."); \ } -#define TE_ROUTER_INDEX_TYPE_SWITCH_ALL(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kInt32: { \ - using type = int32_t; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kInt64: { \ - using type = int64_t; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Invalid type."); \ +#define TE_ROUTER_INDEX_TYPE_SWITCH_ALL(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kInt32: { \ + using type = int32_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt64: { \ + using type = int64_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported router index dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Int32, Int64, BFloat16, " \ + "Float32."); \ } } // namespace fused_router } // namespace transformer_engine diff --git a/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh index eb99edc4d3..aa2bde4203 100644 --- a/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh +++ b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh @@ -326,17 +326,17 @@ void CutlassGroupedGemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, // Check can implement the kernel. if (gemm.can_implement(arguments) != cutlass::Status::kSuccess) { - NVTE_CHECK(false, "Failed to implement CUTLASS Grouped GEMM"); + NVTE_ERROR("Failed to implement CUTLASS Grouped GEMM with ", num_gemms, " GEMMs"); } // Initialize the kernel. if (gemm.initialize(arguments, kernel_workspace_ptr) != cutlass::Status::kSuccess) { - NVTE_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM"); + NVTE_ERROR("Failed to initialize CUTLASS Grouped GEMM with ", num_gemms, " GEMMs"); } // Execute the kernel in the current stream. if (gemm.run(stream) != cutlass::Status::kSuccess) { - NVTE_CHECK(false, "Failed to run CUTLASS Grouped GEMM"); + NVTE_ERROR("Failed to run CUTLASS Grouped GEMM with ", num_gemms, " GEMMs"); } } diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 852b418b39..11f12775c5 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -116,7 +116,9 @@ void TeNormalizationPlan::execute(Tensor* z, void* x_dptr, void* beta_dptr, void* mean_dptr, void* eps_dptr, void* rsigma_dptr, void* workspace_dptr, cudaStream_t stream) { - NVTE_ERROR("Backward normalization should not call the forward execute function!"); + NVTE_ERROR( + "Backward normalization should not call the forward execute function. " + "Use the backward-specific execute overload instead."); } template @@ -165,7 +167,9 @@ void TeNormalizationPlan::execute(void* x_dptr, void* gamma void* dx_dptr, void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, cudaStream_t stream) { - NVTE_ERROR("Forward normalization should not call the backward execute function!"); + NVTE_ERROR( + "Forward normalization should not call the backward execute function. " + "Use the forward-specific execute overload instead."); } template <> diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index cd02074fbd..1875f4f690 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -650,7 +650,7 @@ NVTEShape nvte_make_shape(const size_t *data, size_t ndim) { NVTEShape nvte_tensor_shape(const NVTETensor tensor) { auto *t = transformer_engine::convertNVTETensor(tensor); if (t == nullptr) { - NVTE_ERROR("Invalid tensor"); + NVTE_ERROR("Invalid tensor: received null pointer in nvte_tensor_shape"); } // Determine tensor shape depending on tensor format @@ -662,7 +662,7 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { auto *t = transformer_engine::convertNVTETensor(tensor); if (t == nullptr) { - NVTE_ERROR("Invalid tensor"); + NVTE_ERROR("Invalid tensor: received null pointer in nvte_tensor_columnwise_shape"); } const std::vector &shape = t->columnwise_data.shape; return nvte_make_shape(shape.data(), shape.size()); diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 0e286009a5..3a8536587c 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -463,7 +463,8 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size std::is_same_v) { dataType = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; } else { - NVTE_CHECK(false, "Invalid Output type (must be FP8)."); + NVTE_ERROR( + "Invalid output type for blockwise transpose (must be FP8: Float8E4M3 or Float8E5M2)."); } CUtensorMap tensor_map_output_trans{}; diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index e5d75e1501..861ddfadae 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -165,13 +165,25 @@ def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout): kv_max_seqlen = q_max_seqlen num_gqa_groups = attn_heads v_head_dim = q_head_dim - assert nqkv == 3 + assert nqkv == 3, ( + f"Expected nqkv == 3 for qkvpacked layout, but got nqkv={nqkv} from" + f" q_aval.shape={q_aval.shape}" + ) elif qkv_layout.is_kvpacked(): *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, v_head_dim = k_aval.shape - assert q_batch_shape == kv_batch_shape - assert q_head_dim == v_head_dim - assert nkv == 2 + assert q_batch_shape == kv_batch_shape, ( + f"Mismatched batch shapes for kvpacked layout: q_batch_shape={q_batch_shape}," + f" kv_batch_shape={kv_batch_shape}" + ) + assert q_head_dim == v_head_dim, ( + f"Mismatched head dims for kvpacked layout: q_head_dim={q_head_dim}," + f" v_head_dim={v_head_dim}" + ) + assert nkv == 2, ( + f"Expected nkv == 2 for kvpacked layout, but got nkv={nkv} from" + f" k_aval.shape={k_aval.shape}" + ) elif qkv_layout.is_separate(): *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape *k_batch_shape, k_max_seqlen, k_num_gqa_groups, k_head_dim = k_aval.shape @@ -244,9 +256,13 @@ def check_seed(self, seed, dropout_probability, is_training): ) seed = seed.astype(self.rng_state_dtype) - assert seed.dtype == self.rng_state_dtype + assert ( + seed.dtype == self.rng_state_dtype + ), f"Expected seed.dtype={self.rng_state_dtype}, but got seed.dtype={seed.dtype}" # Backend takes an int64_t seed, so only the first two u32 elements are taken - assert seed.size >= self.seed_size + assert ( + seed.size >= self.seed_size + ), f"Expected seed.size >= {self.seed_size}, but got seed.size={seed.size}" return seed @@ -363,7 +379,9 @@ def abstract( # 32-bit unsigned int to get the buffer size we need in the C++ kernel checker = _FusedAttnRNGStateChecker() seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype) - assert seed_dtype == checker.rng_state_dtype + assert ( + seed_dtype == checker.rng_state_dtype + ), f"Expected seed_dtype={checker.rng_state_dtype}, but got seed_dtype={seed_dtype}" rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) @@ -408,11 +426,19 @@ def abstract( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - assert softmax_offset_aval.dtype == jnp.float32 + assert ( + softmax_offset_aval.dtype == jnp.float32 + ), f"Expected softmax_offset_aval.dtype=float32, but got {softmax_offset_aval.dtype}" if config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX: - assert softmax_offset_aval.shape == (1, attn_heads, 1, 1) + assert softmax_offset_aval.shape == (1, attn_heads, 1, 1), ( + f"Expected softmax_offset_aval.shape=(1, {attn_heads}, 1, 1) for" + f" {config.softmax_type}, but got {softmax_offset_aval.shape}" + ) else: - assert softmax_offset_aval.shape == (0,) + assert softmax_offset_aval.shape == (0,), ( + "Expected softmax_offset_aval.shape=(0,) for VANILLA_SOFTMAX, but got" + f" {softmax_offset_aval.shape}" + ) return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval @@ -533,7 +559,9 @@ def impl( _kv_segment_pos, config: _FusedAttnConfig, ): - assert FusedAttnFwdPrimitive.inner_primitive is not None + assert ( + FusedAttnFwdPrimitive.inner_primitive is not None + ), "FusedAttnFwdPrimitive.inner_primitive has not been registered" sequence_descriptor = SequenceDescriptor( seqlens=(q_seqlen, kv_seqlen), @@ -627,7 +655,9 @@ def convert_to_2d(offsets, batch, max_seqlen): @staticmethod def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) - assert FusedAttnFwdPrimitive.outer_primitive is not None + assert ( + FusedAttnFwdPrimitive.outer_primitive is not None + ), "FusedAttnFwdPrimitive.outer_primitive has not been registered" q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims out_bdims = q_bdim, q_bdim, seed_bdim @@ -778,8 +808,15 @@ def abstract( v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype) - assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype - assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype + assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype, ( + f"Mismatched dtypes: q_dtype={q_dtype}, k_dtype={k_dtype}, v_dtype={v_dtype}," + f" bias_dtype={bias_dtype}, doutput_dtype={doutput_dtype}" + ) + assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype, ( + "Mismatched seqlen dtypes:" + f" q_seqlen_or_cu_seqlen_aval.dtype={q_seqlen_or_cu_seqlen_aval.dtype}," + f" kv_seqlen_or_cu_seqlen_aval.dtype={kv_seqlen_or_cu_seqlen_aval.dtype}" + ) ( batch_shape, @@ -983,7 +1020,9 @@ def impl( _kv_segment_pos, config, ): - assert FusedAttnBwdPrimitive.inner_primitive is not None + assert ( + FusedAttnBwdPrimitive.inner_primitive is not None + ), "FusedAttnBwdPrimitive.inner_primitive has not been registered" sequence_descriptor = SequenceDescriptor( seqlens=(q_seqlen, kv_seqlen), @@ -1023,7 +1062,9 @@ def convert_to_2d(offsets, batch, max_seqlen): batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval( q, k, v, config.qkv_layout ) - assert len(batch) == 1 + assert ( + len(batch) == 1 + ), f"Expected len(batch) == 1, but got len(batch)={len(batch)}, batch={batch}" kv_batch = q_batch = batch[0] # Gather valid q_seqlen, which is greater than 0 @@ -1082,7 +1123,9 @@ def convert_to_2d(offsets, batch, max_seqlen): @staticmethod def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) - assert FusedAttnBwdPrimitive.outer_primitive is not None + assert ( + FusedAttnBwdPrimitive.outer_primitive is not None + ), "FusedAttnBwdPrimitive.outer_primitive has not been registered" q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim @@ -3396,7 +3439,9 @@ def fused_attn_fwd( raise ValueError(f"Unknown {qkv_layout=}") if attn_bias_type == AttnBiasType.NO_BIAS: - assert bias is None + assert ( + bias is None + ), f"bias must be None when attn_bias_type is NO_BIAS, but got bias={bias}" bias = jnp.zeros(0, dtype=qkv[0].dtype) if softmax_offset is None: @@ -3414,10 +3459,16 @@ def fused_attn_fwd( softmax_offset, (None, HEAD_AXES, None, None) ) else: - assert softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX + assert softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX, ( + "Expected VANILLA_SOFTMAX when softmax_offset is None and not OFF_BY_ONE_SOFTMAX," + f" but got softmax_type={softmax_type}" + ) softmax_offset = jnp.zeros(0, dtype=jnp.float32) else: - assert softmax_offset.dtype == jnp.float32 + assert softmax_offset.dtype == jnp.float32, ( + "Expected softmax_offset.dtype=float32, but got" + f" softmax_offset.dtype={softmax_offset.dtype}" + ) # Shard by heads dimension if not VANILLA_SOFTMAX if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX: softmax_offset = with_sharding_constraint_by_logical_axes( @@ -3556,7 +3607,9 @@ def fused_attn_bwd( raise ValueError(f"Unknown {qkv_layout=}") if attn_bias_type == AttnBiasType.NO_BIAS: - assert bias is None + assert ( + bias is None + ), f"bias must be None when attn_bias_type is NO_BIAS, but got bias with type={type(bias)}" bias = jnp.zeros(0, dtype=qkv[0].dtype) if softmax_offset is None: diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 70557f29c7..4506adf33b 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -177,17 +177,26 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ flatten_axis=flatten_axis, ) - assert not isinstance(lhs_q, ScaledTensor2x) - assert not isinstance(rhs_q, ScaledTensor2x) + if isinstance(lhs_q, ScaledTensor2x): + raise TypeError( + "Expected lhs_q to not be ScaledTensor2x after quantization, but got" + f" type={type(lhs_q)}" + ) + if isinstance(rhs_q, ScaledTensor2x): + raise TypeError( + "Expected rhs_q to not be ScaledTensor2x after quantization, but got" + f" type={type(rhs_q)}" + ) def has_rht_applied(q: AbstractBaseTensor) -> bool: return isinstance(q, ScaledTensor1x) and q.has_rht_applied - assert has_rht_applied(lhs_q) == has_rht_applied(rhs_q), ( - "With NVFP4_1D_SCALING, if one operand is quantized with RHT, the other must be quantized" - " with RHT as well. This is to ensure the RHT is applied to both and will cancel out in the" - " GEMM." - ) + if has_rht_applied(lhs_q) != has_rht_applied(rhs_q): + raise ValueError( + "With NVFP4_1D_SCALING, if one operand is quantized with RHT, the other must be" + " quantized with RHT as well. This is to ensure the RHT is applied to both and will" + " cancel out in the GEMM." + ) return lhs_q, rhs_q @@ -284,14 +293,15 @@ def collective_gemm_bootstrap( this function with its own unique process_id. """ - assert ( - num_devices_per_process == 1 and jax.local_device_count() == 1 - ), "Only single device per process is supported at the moment!" - assert num_total_devices % num_devices_per_process == 0, ( - f"Invalid num_total_devices={num_total_devices}," - f" num_devices_per_process={num_devices_per_process}" - ) - assert 0 <= process_id < num_total_devices, f"Invalid process_id={process_id}" + if not (num_devices_per_process == 1 and jax.local_device_count() == 1): + raise RuntimeError("Only single device per process is supported at the moment!") + if num_total_devices % num_devices_per_process != 0: + raise ValueError( + f"Invalid num_total_devices={num_total_devices}," + f" num_devices_per_process={num_devices_per_process}" + ) + if not 0 <= process_id < num_total_devices: + raise ValueError(f"Invalid process_id={process_id}") initialize_cgemm_communicator( num_total_devices, num_devices_per_process, @@ -390,10 +400,11 @@ def assert_cublas_requirements(scaling_mode, contracting_size, tensor_name): # Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage alignment = 32 if scaling_mode.is_nvfp4_scaling else 16 - assert contracting_size % alignment == 0, ( - f"cuBLAS GEMM {tensor_name} tensor's contracting dimension must be a multiple of" - f" {alignment} when using quantized inputs. Got contracting_size={contracting_size}" - ) + if contracting_size % alignment != 0: + raise ValueError( + f"cuBLAS GEMM {tensor_name} tensor's contracting dimension must be a multiple of" + f" {alignment} when using quantized inputs. Got contracting_size={contracting_size}" + ) class GemmPrimitive(BasePrimitive): @@ -439,57 +450,63 @@ def _dims_are_consecutive(dims): lhs_contracting_dims, rhs_contracting_dims, ) = map(sanitize_dims, operand_ndims, contracting_dims) - assert _dims_are_consecutive(lhs_contracting_dims), ( - "cuBLAS GEMM expected consecutive contracting dimensions for LHS operand, but got " - f"{lhs_contracting_dims}." - ) - assert _dims_are_consecutive(rhs_contracting_dims), ( - "cuBLAS GEMM expected consecutive contracting dimensions for RHS operand, but got " - f"{rhs_contracting_dims}." - ) + if not _dims_are_consecutive(lhs_contracting_dims): + raise ValueError( + "cuBLAS GEMM expected consecutive contracting dimensions for LHS operand, but got " + f"{lhs_contracting_dims}." + ) + if not _dims_are_consecutive(rhs_contracting_dims): + raise ValueError( + "cuBLAS GEMM expected consecutive contracting dimensions for RHS operand, but got " + f"{rhs_contracting_dims}." + ) lhs_contracting_size, rhs_contracting_size = map( lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]), (lhs.shape, rhs.shape), (lhs_contracting_dims, rhs_contracting_dims), ) - assert lhs_contracting_size == rhs_contracting_size, ( - "cuBLAS GEMM operands have incompatible contracting dimensions: " - f"{lhs.shape} @ idx {lhs_contracting_dims} X {rhs.shape} @ idx {rhs_contracting_dims}." - ) + if lhs_contracting_size != rhs_contracting_size: + raise ValueError( + f"cuBLAS GEMM operands have incompatible contracting dimensions: {lhs.shape} @ idx" + f" {lhs_contracting_dims} X {rhs.shape} @ idx {rhs_contracting_dims}." + ) assert_cublas_requirements(scaling_mode, lhs_contracting_size, "LHS") assert_cublas_requirements(scaling_mode, rhs_contracting_size, "RHS") lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims) if scaling_mode != ScalingMode.NO_SCALING: - assert scaling_mode.is_nvfp4_scaling or _compatible_fp8_gemm_dtypes( - lhs.dtype, rhs.dtype - ), ( - "cuBLAS GEMM quantized operands have incompatible data types: " - f"{lhs.dtype} x {rhs.dtype}." - ) - assert ( - lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0 - ), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." + if not ( + scaling_mode.is_nvfp4_scaling or _compatible_fp8_gemm_dtypes(lhs.dtype, rhs.dtype) + ): + raise ValueError( + "cuBLAS GEMM quantized operands have incompatible data types: " + f"{lhs.dtype} x {rhs.dtype}." + ) + if not (lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0): + raise ValueError( + "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." + ) if ( scaling_mode != ScalingMode.MXFP8_1D_SCALING and not is_fp8_gemm_with_all_layouts_supported() ): - assert not lhs_is_transposed and rhs_is_transposed, ( - "cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) " - "require non-transposed LHS and transposed RHS operands " - "(`contracting_dims=((-1, ), (-1, ))`)." - ) + if lhs_is_transposed or not rhs_is_transposed: + raise ValueError( + "cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) " + "require non-transposed LHS and transposed RHS operands " + "(`contracting_dims=((-1, ), (-1, ))`)." + ) else: - assert lhs.dtype == rhs.dtype, ( - "For TE cuBLAS GEMM for non-quantized inputs, the operand dtypes must be equal." - f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}" - ) + if lhs.dtype != rhs.dtype: + raise ValueError( + "For TE cuBLAS GEMM for non-quantized inputs, the operand dtypes must be equal." + f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}" + ) # Determine output shape and dtype - assert ( - dtypes.canonicalize_dtype(out_dtype).itemsize > 1 - ), "cuBLAS GEMM custom op does not support 8-bit quantized output types." + if not dtypes.canonicalize_dtype(out_dtype).itemsize > 1: + raise ValueError("cuBLAS GEMM custom op does not support 8-bit quantized output types.") lhs_non_contracting_shape, rhs_non_contracting_shape = map( lambda shape, dims: [shape[dim] for dim in range(len(shape)) if dim not in dims], (lhs.shape, rhs.shape), @@ -500,7 +517,8 @@ def _dims_are_consecutive(dims): # Adjust output shape for comm+GEMM overlap if not collective_op.is_none and not is_outer: # Inner abstract - assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + if sequence_dim != 1: + raise ValueError(f"Invalid sequence_dim. Got sequence_dim={sequence_dim}") overlap_out_shape = list(out_shape).copy() if collective_op.is_all_gather: overlap_out_shape[1] *= tpsp_axis_size() @@ -508,23 +526,34 @@ def _dims_are_consecutive(dims): overlap_out_shape[sequence_dim] = ( overlap_out_shape[sequence_dim] // tpsp_axis_size() ) - assert out_dtype == jnp.bfloat16, f"Unsupported out_dtype={out_dtype}" + if out_dtype != jnp.bfloat16: + raise ValueError(f"Unsupported out_dtype={out_dtype}") output = jax.core.ShapedArray(shape=overlap_out_shape, dtype=out_dtype) # Validate bias when present (bias.size > 0 means fuse bias) if bias.size > 0: - assert bias.shape == tuple(rhs_non_contracting_shape), ( - "cuBLAS GEMM bias tensor has incorrect shape, " - f"expected ({tuple(rhs_non_contracting_shape)}, ) but found {bias.shape}." + if bias.shape != tuple(rhs_non_contracting_shape): + raise ValueError( + "cuBLAS GEMM bias tensor has incorrect shape, " + f"expected ({tuple(rhs_non_contracting_shape)}, ) but found {bias.shape}." + ) + if bias.dtype != out_dtype: + raise ValueError( + "cuBLAS GEMM bias tensor has incorrect data type, " + f"expected {out_dtype} but found {bias.dtype}." + ) + + if alpha.size != 1 or alpha.dtype != jnp.float32: + raise ValueError( + f"Expected alpha to be a single float32 scalar, but got alpha.size={alpha.size}," + f" alpha.dtype={alpha.dtype}" ) - assert bias.dtype == out_dtype, ( - "cuBLAS GEMM bias tensor has incorrect data type, " - f"expected {out_dtype} but found {bias.dtype}." + if beta.size != 1 or beta.dtype != jnp.float32: + raise ValueError( + f"Expected beta to be a single float32 scalar, but got beta.size={beta.size}," + f" beta.dtype={beta.dtype}" ) - assert alpha.size == 1 and alpha.dtype == jnp.float32 - assert beta.size == 1 and beta.dtype == jnp.float32 - # Declare cuBLAS workspace workspace_size = get_cublas_workspace_size_bytes() # NVFP4 swizzling happen in via nvte kernel instead of JAX transposes @@ -629,16 +658,19 @@ def impl( and not is_outer and not lhs.shape[0] == 1 ): - assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + if sequence_dim != 1: + raise ValueError(f"Invalid sequence_dim. Got sequence_dim={sequence_dim}") original_shape = lhs.shape - assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( - f"Original_shape[0]={original_shape[0]} is not divisible by" - f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" - ) - assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( - f"Original_shape[1]={original_shape[1]} is not divisible by" - f" tpsp_axis_size()={tpsp_axis_size()}" - ) + if original_shape[0] % dp_or_fsdp_axis_size() != 0 and original_shape[0] != 1: + raise ValueError( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + if original_shape[1] % tpsp_axis_size() != 0 and original_shape[1] != 1: + raise ValueError( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) reshaped = lhs.reshape( dp_or_fsdp_axis_size(), int(original_shape[0] / dp_or_fsdp_axis_size()), @@ -673,16 +705,19 @@ def impl( and not is_outer and not output.shape[0] == 1 ): - assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + if sequence_dim != 1: + raise ValueError(f"Invalid sequence_dim. Got sequence_dim={sequence_dim}") original_shape = output.shape - assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( - f"Original_shape[0]={original_shape[0]} is not divisible by" - f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" - ) - assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( - f"Original_shape[1]={original_shape[1]} is not divisible by" - f" tpsp_axis_size()={tpsp_axis_size()}" - ) + if original_shape[0] % dp_or_fsdp_axis_size() != 0 and original_shape[0] != 1: + raise ValueError( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + if original_shape[1] % tpsp_axis_size() != 0 and original_shape[1] != 1: + raise ValueError( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) reshaped = output.reshape( tpsp_axis_size(), dp_or_fsdp_axis_size(), @@ -745,13 +780,15 @@ def batcher( is_outer, ): del transpose_batch_sequence, sequence_dim, is_outer - assert GemmPrimitive.outer_primitive is not None + if GemmPrimitive.outer_primitive is None: + raise RuntimeError("GemmPrimitive.outer_primitive has not been registered") lhs_bdims, _, rhs_bdims, *_ = batch_dims # Batched GEMM is not supported - assert ( - lhs_bdims is None and rhs_bdims is None - ), f"(Batching is not supported, got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims})" + if not (lhs_bdims is None and rhs_bdims is None): + raise RuntimeError( + f"Batching is not supported, got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims}" + ) out_bdims = (None,) return ( @@ -806,7 +843,8 @@ def _parse_operand_output_specs( for l in lhs_cspecs: for r in rhs_cspecs: if l is not None and l == r: - assert reduce_spec is None, "Multiple reduce dimension is detected!" + if reduce_spec is not None: + raise RuntimeError("Multiple reduce dimension is detected!") reduce_spec = l sequence_dim = None @@ -822,18 +860,20 @@ def _parse_operand_output_specs( " Please check your sharding configuration." ) from exc sequence_dim = tpsp_idx - assert (sequence_dim == 1) ^ transpose_batch_sequence, ( - "CollectiveGEMM supports only (sequence_dim=1 and transpose_batch_sequence=False)" - " or (sequence_dim=0 and transpose_batch_sequence=True). Received:" - f" sequence_dim={sequence_dim}," - f" transpose_batch_sequence={transpose_batch_sequence}." - ) + if not (sequence_dim == 1) ^ transpose_batch_sequence: + raise ValueError( + "CollectiveGEMM supports only (sequence_dim=1 and" + " transpose_batch_sequence=False) or (sequence_dim=0 and" + f" transpose_batch_sequence=True). Received: sequence_dim={sequence_dim}," + f" transpose_batch_sequence={transpose_batch_sequence}." + ) elif collective_op.is_reduce_scatter: - assert reduce_spec == gsr.tpsp_resource, ( - "Only CollectiveGemm RS with the Reduction over the TPSP axis is supported! Got" - f" reduce_spec={reduce_spec}, tpsp_resource={gsr.tpsp_resource}" - ) + if reduce_spec != gsr.tpsp_resource: + raise ValueError( + "Only CollectiveGemm RS with the Reduction over the TPSP axis is supported! Got" + f" reduce_spec={reduce_spec}, tpsp_resource={gsr.tpsp_resource}" + ) sequence_dim = int(not transpose_batch_sequence) if reduce_spec is not None: @@ -886,14 +926,18 @@ def _parse_operand_output_specs( # Only do AG Sequence dim if not Overlap RS if collective_op.is_all_gather: - assert sequence_dim <= len( - lhs_non_cspecs - ), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}" + if sequence_dim > len(lhs_non_cspecs): + raise ValueError( + f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs:" + f" {lhs_non_cspecs}" + ) out_specs = out_specs[:sequence_dim] + (None,) + out_specs[sequence_dim + 1 :] elif collective_op.is_reduce_scatter: - assert sequence_dim <= len( - lhs_non_cspecs - ), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}" + if sequence_dim > len(lhs_non_cspecs): + raise ValueError( + f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs:" + f" {lhs_non_cspecs}" + ) out_specs = ( out_specs[:sequence_dim] + (gsr.tpsp_resource,) + out_specs[sequence_dim + 1 :] ) @@ -912,7 +956,8 @@ def _parse_operand_output_specs( bias_specs = rhs_non_cspecs if arg_infos[4].size > 0 else (None,) # bias is operand index 4 if not collective_op.is_none: - assert sequence_dim >= 0, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + if sequence_dim < 0: + raise ValueError(f"Invalid sequence_dim. Got sequence_dim={sequence_dim}") return ( (lhs_specs, rhs_specs, bias_specs), @@ -1154,10 +1199,11 @@ def _te_gemm( lhs_amax = rhs_amax = None # Extract GEMM custom op inputs from quantized operands if isinstance(lhs_q, ScaledTensor): - assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, ( - "cuBLAS GEMM with quantized LHS and non-quantized RHS operands requires a valid " - "`Quantizer` object to quantize the RHS operand." - ) + if not isinstance(rhs_q, ScaledTensor) and rhs_quantizer is None: + raise ValueError( + "cuBLAS GEMM with quantized LHS and non-quantized RHS operands requires a valid " + "`Quantizer` object to quantize the RHS operand." + ) if isinstance(lhs_q, ScaledTensor2x): # Choose the quantization of the contracting dimension(s) lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() @@ -1169,21 +1215,23 @@ def _te_gemm( lhs_amax = lhs_q.amax if isinstance(rhs_q, ScaledTensor): - assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( - "cuBLAS GEMM with non-quantized LHS and quantized RHS operands requires a valid " - "`Quantizer` object to quantize the LHS operand." - ) + if not isinstance(lhs_q, ScaledTensor) and lhs_quantizer is None: + raise ValueError( + "cuBLAS GEMM with non-quantized LHS and quantized RHS operands requires a valid " + "`Quantizer` object to quantize the LHS operand." + ) if isinstance(rhs_q, ScaledTensor2x): # Choose the quantization of the contracting dimension(s) rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() - assert ( + if not ( rhs_q.scaling_mode == lhs_q.scaling_mode or rhs_q.scaling_mode.is_nvfp4_scaling and lhs_q.scaling_mode.is_nvfp4_scaling - ), ( - "cuBLAS GEMM quantized operands have mismatched scaling types, " - f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}." - ) + ): + raise ValueError( + "cuBLAS GEMM quantized operands have mismatched scaling types, " + f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}." + ) rhs_data = rhs_q.data rhs_scale_inv = rhs_q.scale_inv if rhs_q.data_layout == "T": @@ -1193,7 +1241,8 @@ def _te_gemm( alpha = jnp.ones((1,), jnp.float32) beta = jnp.zeros((1,), jnp.float32) if scaling_mode.is_nvfp4_scaling: - assert lhs_amax is not None and rhs_amax is not None + if lhs_amax is None or rhs_amax is None: + raise ValueError("NVFP4 scaling requires non-None amax for both LHS and RHS operands") lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs_amax) rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs_amax) alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv @@ -1268,7 +1317,10 @@ def impl( group_sizes, num_gemms, ): - assert GroupedGemmCopySizesPrimitive.inner_primitive is not None + if GroupedGemmCopySizesPrimitive.inner_primitive is None: + raise RuntimeError( + "GroupedGemmCopySizesPrimitive.inner_primitive has not been registered" + ) out = GroupedGemmCopySizesPrimitive.inner_primitive.bind( group_sizes, num_gemms=num_gemms, @@ -1372,23 +1424,20 @@ def abstract( shape=(int64_workspace_size,), dtype=jnp.uint8 ) - assert len(additional_args) == 2, ( - "Expected additional_args to contain alpha, beta for the graph-safe grouped GEMM" - f" primitive, but got {len(additional_args)} arguments." - ) + if len(additional_args) != 2: + raise ValueError( + "Expected additional_args to contain alpha, beta for the graph-safe grouped" + f" GEMM primitive, but got {len(additional_args)} arguments." + ) alpha_aval, beta_aval = additional_args - assert alpha_aval.shape == ( - num_groups, - ), f"Expected alpha shape {(num_groups,)}, got {alpha_aval.shape}" - assert ( - alpha_aval.dtype == jnp.float32 - ), f"Expected alpha dtype float32, got {alpha_aval.dtype}" - assert beta_aval.shape == ( - num_groups, - ), f"Expected beta shape {(num_groups,)}, got {beta_aval.shape}" - assert ( - beta_aval.dtype == jnp.float32 - ), f"Expected beta dtype float32, got {beta_aval.dtype}" + if alpha_aval.shape != (num_groups,): + raise ValueError(f"Expected alpha shape {(num_groups,)}, got {alpha_aval.shape}") + if alpha_aval.dtype != jnp.float32: + raise ValueError(f"Expected alpha dtype float32, got {alpha_aval.dtype}") + if beta_aval.shape != (num_groups,): + raise ValueError(f"Expected beta shape {(num_groups,)}, got {beta_aval.shape}") + if beta_aval.dtype != jnp.float32: + raise ValueError(f"Expected beta dtype float32, got {beta_aval.dtype}") return (out_aval, cublas_workspace_aval, setup_workspace_aval, int64_workspace_aval) @@ -1498,7 +1547,8 @@ def impl( use_async_d2h_group_sizes, use_v2_ffi, ): - assert GroupedGemmPrimitive.inner_primitive is not None + if GroupedGemmPrimitive.inner_primitive is None: + raise RuntimeError("GroupedGemmPrimitive.inner_primitive has not been registered") if use_v2_ffi: additional_args = (additional_arg_0, additional_arg_1) else: @@ -1586,30 +1636,37 @@ def _jax_scaled_matmul( """ JAX GEMM for MXFP8 via scaled_matmul """ - assert rhs.scaling_mode in ( + if rhs.scaling_mode not in ( ScalingMode.MXFP8_1D_SCALING, ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING, - ), f"rhs does not have MXFP8 or NVFP4 scaling mode, got rhs.scaling_mode={rhs.scaling_mode}" + ): + raise ValueError( + "rhs does not have MXFP8 or NVFP4 scaling mode, got" + f" rhs.scaling_mode={rhs.scaling_mode}" + ) (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums expected_lhs_is_colwise = lhs_contract[-1] != lhs.data.ndim - 1 expected_rhs_is_colwise = rhs_contract[-1] != rhs.data.ndim - 1 - assert lhs.is_colwise is expected_lhs_is_colwise, ( - f"LHS with unexpected quantize dimension.\nExpect is_colwise={expected_lhs_is_colwise}, got" - f" {lhs.is_colwise}" - ) - assert rhs.is_colwise is expected_rhs_is_colwise, ( - f"RHS with unexpected quantize dimension.\nExpect is_colwise={expected_rhs_is_colwise}, got" - f" {rhs.is_colwise}" - ) + if lhs.is_colwise is not expected_lhs_is_colwise: + raise ValueError( + f"LHS with unexpected quantize dimension.\nExpect is_colwise={expected_lhs_is_colwise}," + f" got {lhs.is_colwise}" + ) + if rhs.is_colwise is not expected_rhs_is_colwise: + raise ValueError( + f"RHS with unexpected quantize dimension.\nExpect is_colwise={expected_rhs_is_colwise}," + f" got {rhs.is_colwise}" + ) if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: out_dtype = lhs.dq_dtype - assert ( - lhs.data_layout == "N" and rhs.data_layout == "N" - ), f"Got lhs.data_layout={lhs.data_layout}, rhs.data_layout={rhs.data_layout}" + if not (lhs.data_layout == "N" and rhs.data_layout == "N"): + raise ValueError( + f"Got lhs.data_layout={lhs.data_layout}, rhs.data_layout={rhs.data_layout}" + ) else: if lhs.data_layout == "T": lhs_contract = transpose_dims( @@ -1641,7 +1698,8 @@ def _jax_scaled_matmul( lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=out_dtype ) if lhs.scaling_mode.is_nvfp4_scaling: - assert lhs.amax is not None and rhs.amax is not None + if lhs.amax is None or rhs.amax is None: + raise ValueError("NVFP4 scaling requires non-None amax for both LHS and RHS operands") lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs.amax) rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs.amax) alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv @@ -1674,9 +1732,10 @@ def _jax_gemm( def _jax_gemm_impl(lhs, rhs): if lhs.scaling_mode.is_tensor_scaling(): - assert ( - rhs.scaling_mode == lhs.scaling_mode - ), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}" + if rhs.scaling_mode != lhs.scaling_mode: + raise ValueError( + f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}" + ) precision = ( jax.lax.Precision.HIGHEST if use_split_accumulator else jax.lax.Precision.DEFAULT @@ -1760,7 +1819,8 @@ def gemm( # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled if not GemmPrimitive.enabled(): - assert collective_op.is_none, "JAX GEMM does not support collective GEMM" + if not collective_op.is_none: + raise RuntimeError("JAX GEMM does not support collective GEMM") output = _jax_gemm( lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer, use_split_accumulator ) @@ -1857,7 +1917,10 @@ def grouped_gemm( del precision if isinstance(lhs, jnp.ndarray): - assert isinstance(rhs, jnp.ndarray) + if not isinstance(rhs, jnp.ndarray): + raise TypeError( + f"Expected rhs to be jnp.ndarray when lhs is jnp.ndarray, but got type={type(rhs)}" + ) out_dtype = lhs.dtype lhs_shape = lhs.shape rhs_shape = rhs.shape @@ -1866,7 +1929,11 @@ def grouped_gemm( lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32) scaling_mode = ScalingMode.NO_SCALING elif isinstance(lhs, GroupedScaledTensor1x): - assert isinstance(rhs, GroupedScaledTensor1x) + if not isinstance(rhs, GroupedScaledTensor1x): + raise TypeError( + "Expected rhs to be GroupedScaledTensor1x when lhs is GroupedScaledTensor1x, but" + f" got type={type(rhs)}" + ) out_dtype = lhs.dq_dtype lhs_shape = lhs.original_shape rhs_shape = rhs.original_shape @@ -1874,7 +1941,11 @@ def grouped_gemm( rhs_data = rhs.data lhs_scale_inv = lhs.scale_inv rhs_scale_inv = rhs.scale_inv - assert lhs.scaling_mode == rhs.scaling_mode + if lhs.scaling_mode != rhs.scaling_mode: + raise ValueError( + f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode}," + f" rhs.scaling_mode={rhs.scaling_mode}" + ) scaling_mode = lhs.scaling_mode else: raise TypeError("Unsupported lhs type object!") @@ -1911,8 +1982,16 @@ def grouped_gemm( and not isinstance(rhs, ScaledTensor) and quantizer_set != noop_quantizer_set ): - assert isinstance(quantizer_set.x, GroupedQuantizer) - assert type(quantizer_set.x) is type(quantizer_set.kernel) + if not isinstance(quantizer_set.x, GroupedQuantizer): + raise TypeError( + "Expected quantizer_set.x to be GroupedQuantizer, but got" + f" type={type(quantizer_set.x)}" + ) + if type(quantizer_set.x) is not type(quantizer_set.kernel): + raise TypeError( + "Expected quantizer_set.x and quantizer_set.kernel to have the same type, but got" + f" {type(quantizer_set.x)} and {type(quantizer_set.kernel)}" + ) scaling_mode = quantizer_set.x.scaling_mode if ( quantizer_set.x.scaling_mode.is_tensor_scaling() @@ -1939,9 +2018,8 @@ def grouped_gemm( lhs_shape = lhs_q.original_shape rhs_shape = rhs_q.original_shape - assert not ( - lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2 - ), "FP8 GEMM does not support E5M2 * E5M2" + if lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2: + raise ValueError("FP8 GEMM does not support E5M2 * E5M2") # Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs # thus additional transpose is required @@ -1954,12 +2032,10 @@ def grouped_gemm( rhs_layout_is_T = rhs_q.data_layout == "T" # we can't apply _shape_normalization on the grouped input # thus we need to ensure that lhs is in N and rhs is in T - assert ( - lhs_is_trans == lhs_layout_is_T - ), "lhs input must be transposed before calling grouped_gemm" - assert ( - not rhs_is_trans == rhs_layout_is_T - ), "rhs input must be transposed before calling grouped_gemm" + if lhs_is_trans != lhs_layout_is_T: + raise RuntimeError("lhs input must be transposed before calling grouped_gemm") + if (not rhs_is_trans) != rhs_layout_is_T: + raise RuntimeError("rhs input must be transposed before calling grouped_gemm") lhs_is_trans = False rhs_is_trans = True lhs_ndim = len(lhs_shape) @@ -1978,28 +2054,36 @@ def grouped_gemm( # Calling GroupedGEMM Custom Call K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim) K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim) - assert K_lhs == K_rhs + if K_lhs != K_rhs: + raise ValueError( + f"Mismatched contracting dimensions: K_lhs={K_lhs}, K_rhs={K_rhs} (from" + f" lhs_shape={lhs_shape}, rhs_shape={rhs_shape})" + ) M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim)) N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:]) # Exclude G if is_grouped_dense_wgrad: N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)) else: - assert group_sizes.size == rhs_shape[0] + if group_sizes.size != rhs_shape[0]: + raise ValueError( + "Expected group_sizes.size == rhs_shape[0], but got" + f" group_sizes.size={group_sizes.size}, rhs_shape[0]={rhs_shape[0]}" + ) has_bias = bias is not None - if has_bias: - assert bias.shape == ( - group_sizes.size, - N, - ), f"bias shape {bias.shape} does not match expected shape {(group_sizes.size, N)}" + if has_bias and bias.shape != (group_sizes.size, N): + raise ValueError( + f"Expected bias.shape=({group_sizes.size}, {N}), but got bias.shape={bias.shape}" + ) bias = jnp.empty((), jnp.float32) if bias is None else bias - assert group_offset is None, ( - "group_offset is not supported yet and is instead computed" - " internally assuming contiguous grouping. Any padding is included in the group_sizes" - " and padded with zeros to not affect the result of the MoE block." - ) + if group_offset is not None: + raise RuntimeError( + "group_offset is not supported yet and is instead computed" + " internally assuming contiguous grouping. Any padding is included in the group_sizes" + " and padded with zeros to not affect the result of the MoE block." + ) use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype, has_bias) if use_v2_ffi: diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 70fdf4c474..29292f946b 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -132,9 +132,17 @@ def abstract( ) x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) - assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert scale_aval is None or scale_aval.dtype == jnp.float32 - assert amax_aval is None or amax_aval.dtype == jnp.float32 + assert x_dtype in [ + jnp.float32, + jnp.float16, + jnp.bfloat16, + ], f"Unsupported x_dtype={x_dtype}, expected one of [float32, float16, bfloat16]" + assert ( + scale_aval is None or scale_aval.dtype == jnp.float32 + ), f"Expected scale_aval.dtype=float32, but got scale_aval.dtype={scale_aval.dtype}" + assert ( + amax_aval is None or amax_aval.dtype == jnp.float32 + ), f"Expected amax_aval.dtype=float32, but got amax_aval.dtype={amax_aval.dtype}" assert ( scaling_mode != ScalingMode.MXFP8_1D_SCALING.value @@ -159,7 +167,10 @@ def abstract( mu_rsigama_dtype = jnp.float32 if norm_type == NVTE_Norm_Type.LayerNorm: - assert gamma_aval.size == beta_aval.size + assert gamma_aval.size == beta_aval.size, ( + "Expected gamma_aval.size == beta_aval.size, but got" + f" gamma_aval.size={gamma_aval.size}, beta_aval.size={beta_aval.size}" + ) assert gamma_aval.dtype == beta_aval.dtype, ( f"gamma and beta should have the same dtype, but got {gamma_aval.dtype} and " f"{beta_aval.dtype}" @@ -265,18 +276,35 @@ def lowering( del out_dtype, scale_dtype, is_outer, amax_scope, transpose_batch_sequence x_aval, scale_aval, amax_aval, gamma_aval, beta_aval = ctx.avals_in - assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert scale_aval is None or scale_aval.dtype == jnp.float32 - assert amax_aval is None or amax_aval.dtype == jnp.float32 + assert x_aval.dtype in [ + jnp.float32, + jnp.float16, + jnp.bfloat16, + ], f"Unsupported x_aval.dtype={x_aval.dtype}, expected one of [float32, float16, bfloat16]" + assert ( + scale_aval is None or scale_aval.dtype == jnp.float32 + ), f"Expected scale_aval.dtype=float32, but got scale_aval.dtype={scale_aval.dtype}" + assert ( + amax_aval is None or amax_aval.dtype == jnp.float32 + ), f"Expected amax_aval.dtype=float32, but got amax_aval.dtype={amax_aval.dtype}" g_type = ir.RankedTensorType(gamma.type) g_shape = g_type.shape if norm_type == NVTE_Norm_Type.LayerNorm: - assert gamma_aval.dtype == beta_aval.dtype + assert gamma_aval.dtype == beta_aval.dtype, ( + "Expected gamma and beta to have the same dtype, but got" + f" gamma_aval.dtype={gamma_aval.dtype}, beta_aval.dtype={beta_aval.dtype}" + ) b_type = ir.RankedTensorType(beta.type) b_shape = b_type.shape - assert g_type == b_type - assert g_shape == b_shape + assert g_type == b_type, ( + f"Expected gamma and beta to have the same IR type, but got gamma_type={g_type}," + f" beta_type={b_type}" + ) + assert g_shape == b_shape, ( + f"Expected gamma and beta to have the same shape, but got gamma_shape={g_shape}," + f" beta_shape={b_shape}" + ) sm_margin = get_forward_sm_margin() return ffi.ffi_lowering( @@ -321,7 +349,9 @@ def impl( to describe implementation """ del is_outer - assert NormFwdPrimitive.inner_primitive is not None + assert ( + NormFwdPrimitive.inner_primitive is not None + ), "NormFwdPrimitive.inner_primitive has not been registered" ( out, colwise_out, @@ -391,7 +421,9 @@ def batcher( to describe batch rules for vmap """ check_valid_batch_dims(batch_dims) - assert NormFwdPrimitive.outer_primitive is not None + assert ( + NormFwdPrimitive.outer_primitive is not None + ), "NormFwdPrimitive.outer_primitive has not been registered" x, scale, amax, gamma, beta = batched_args x_bdim, scale_bdim, _, _, _ = batch_dims @@ -706,13 +738,26 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, norm_type, zero_ w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype) rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype) - assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype - assert dz_aval.shape == x_aval.shape + assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype, ( + f"Expected dz_aval.dtype={w_dtype} (matching gamma dtype), but got" + f" dz_aval.dtype={dtypes.canonicalize_dtype(dz_aval.dtype)}" + ) + assert dz_aval.shape == x_aval.shape, ( + f"Expected dz_aval.shape == x_aval.shape, but got dz_aval.shape={dz_aval.shape}," + f" x_aval.shape={x_aval.shape}" + ) if norm_type == NVTE_Norm_Type.LayerNorm: mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype) - assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1] - assert mu_dtype == rsigma_dtype == jnp.float32 + assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1], ( + "Expected mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1], but got" + f" mu_aval.shape={mu_aval.shape}, rsigma_aval.shape={rsigma_aval.shape}," + f" x_aval.shape[:-1]={x_aval.shape[:-1]}" + ) + assert mu_dtype == rsigma_dtype == jnp.float32, ( + f"Expected mu_dtype == rsigma_dtype == float32, but got mu_dtype={mu_dtype}," + f" rsigma_dtype={rsigma_dtype}" + ) dx_aval = dz_aval dgamma_aval = dbeta_aval = gamma_aval @@ -756,8 +801,14 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma): g_shape = g_type.shape b_type = ir.RankedTensorType(gamma.type) b_shape = b_type.shape - assert g_type == b_type - assert g_shape == b_shape + assert g_type == b_type, ( + f"Expected gamma and beta to have the same IR type, but got gamma_type={g_type}," + f" beta_type={b_type}" + ) + assert g_shape == b_shape, ( + f"Expected gamma and beta to have the same shape, but got gamma_shape={g_shape}," + f" beta_shape={b_shape}" + ) sm_margin = get_backward_sm_margin() return ffi.ffi_lowering(NormBwdPrimitive.name)( @@ -774,7 +825,9 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma): @staticmethod def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma): - assert NormBwdPrimitive.inner_primitive is not None + assert ( + NormBwdPrimitive.inner_primitive is not None + ), "NormBwdPrimitive.inner_primitive has not been registered" dx, dgamma, dbeta, _ = NormBwdPrimitive.inner_primitive.bind( dz, x, mu, rsigma, gamma, norm_type=norm_type, zero_centered_gamma=zero_centered_gamma ) @@ -783,7 +836,9 @@ def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma): @staticmethod def batcher(batched_args, batch_dims, *, norm_type, zero_centered_gamma): check_valid_batch_dims(batch_dims) - assert NormBwdPrimitive.outer_primitive is not None + assert ( + NormBwdPrimitive.outer_primitive is not None + ), "NormBwdPrimitive.outer_primitive has not been registered" dz, x, mu, rsigma, gamma = batched_args _, x_bdim, _, _, gamma_bdim = batch_dims diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py index 1fce6d2fd7..031ab483a0 100644 --- a/transformer_engine/jax/cpp_extensions/router.py +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -115,7 +115,10 @@ def impl( score_function, compute_aux_scores, ): - assert FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive is not None + if FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive is None: + raise RuntimeError( + "FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive has not been registered" + ) return FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive.bind( logits, expert_bias, @@ -141,7 +144,10 @@ def batcher( score_function, compute_aux_scores, ): - assert FusedTopkWithScoreFunctionFwdPrimitive.outer_primitive is not None + if FusedTopkWithScoreFunctionFwdPrimitive.outer_primitive is None: + raise RuntimeError( + "FusedTopkWithScoreFunctionFwdPrimitive.outer_primitive has not been registered" + ) logits, expert_bias = batched_args logits_bdim, _ = batch_dims return ( @@ -284,7 +290,10 @@ def impl( score_function, compute_aux_scores, ): - assert FusedTopkWithScoreFunctionBwdPrimitive.inner_primitive is not None + if FusedTopkWithScoreFunctionBwdPrimitive.inner_primitive is None: + raise RuntimeError( + "FusedTopkWithScoreFunctionBwdPrimitive.inner_primitive has not been registered" + ) return FusedTopkWithScoreFunctionBwdPrimitive.inner_primitive.bind( routing_map, intermediate, @@ -307,7 +316,10 @@ def batcher( score_function, compute_aux_scores, ): - assert FusedTopkWithScoreFunctionBwdPrimitive.outer_primitive is not None + if FusedTopkWithScoreFunctionBwdPrimitive.outer_primitive is None: + raise RuntimeError( + "FusedTopkWithScoreFunctionBwdPrimitive.outer_primitive has not been registered" + ) routing_map, intermediate, grad_probs = batched_args _, _, grad_probs_bdim = batch_dims return ( @@ -402,7 +414,10 @@ def lowering(ctx, probs, tokens_per_expert, *, topk, coeff): @staticmethod def impl(probs, tokens_per_expert, topk, coeff): - assert FusedMoEAuxLossFwdPrimitive.inner_primitive is not None + if FusedMoEAuxLossFwdPrimitive.inner_primitive is None: + raise RuntimeError( + "FusedMoEAuxLossFwdPrimitive.inner_primitive has not been registered" + ) return FusedMoEAuxLossFwdPrimitive.inner_primitive.bind( probs, tokens_per_expert, @@ -412,7 +427,10 @@ def impl(probs, tokens_per_expert, topk, coeff): @staticmethod def batcher(batched_args, batch_dims, *, topk, coeff): - assert FusedMoEAuxLossFwdPrimitive.outer_primitive is not None + if FusedMoEAuxLossFwdPrimitive.outer_primitive is None: + raise RuntimeError( + "FusedMoEAuxLossFwdPrimitive.outer_primitive has not been registered" + ) probs, tokens_per_expert = batched_args probs_bdim, _ = batch_dims return ( @@ -490,7 +508,10 @@ def lowering(ctx, const_buf, tokens_per_expert, grad_aux_loss, *, num_tokens): @staticmethod def impl(const_buf, tokens_per_expert, grad_aux_loss, num_tokens): - assert FusedMoEAuxLossBwdPrimitive.inner_primitive is not None + if FusedMoEAuxLossBwdPrimitive.inner_primitive is None: + raise RuntimeError( + "FusedMoEAuxLossBwdPrimitive.inner_primitive has not been registered" + ) return FusedMoEAuxLossBwdPrimitive.inner_primitive.bind( const_buf, tokens_per_expert, @@ -500,7 +521,10 @@ def impl(const_buf, tokens_per_expert, grad_aux_loss, num_tokens): @staticmethod def batcher(batched_args, batch_dims, *, num_tokens): - assert FusedMoEAuxLossBwdPrimitive.outer_primitive is not None + if FusedMoEAuxLossBwdPrimitive.outer_primitive is None: + raise RuntimeError( + "FusedMoEAuxLossBwdPrimitive.outer_primitive has not been registered" + ) const_buf, tokens_per_expert, grad_aux_loss = batched_args _, _, grad_bdim = batch_dims return ( diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 7decfca6c6..31ce6e72e9 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1445,7 +1445,8 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): def make_grouped_dense_cls(quantization_recipe): """Creates a grouped dense (grouped GEMM) instance for use with TE state module.""" - assert quantization_recipe is None, "Ragged dot grouped GEMM does not support quantization yet" + if quantization_recipe is not None: + raise ValueError("Ragged dot grouped GEMM does not support quantization yet") def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs): del kwargs # Unused diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index ad5a60e4c2..513677e4a1 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -182,7 +182,9 @@ def __call__( is_gqa = h_q != h_kv if is_gqa: - assert (h_q % h_kv == 0) and (h_q >= h_kv) + assert (h_q % h_kv == 0) and ( + h_q >= h_kv + ), f"num_query_heads ({h_q}) must be divisible by and >= num_kv_heads ({h_kv})" group_size = h_q // h_kv grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1])) @@ -428,7 +430,9 @@ def __call__( if self.transpose_batch_sequence: x = x.transpose([1, 0, 2, 3]) - assert x.dtype == query.dtype + assert ( + x.dtype == query.dtype + ), f"output dtype {x.dtype} does not match query dtype {query.dtype}" return x @@ -713,9 +717,13 @@ def __call__( del self.attn_bias_type, self.attn_mask_type, self.qkv_layout if attn_bias_type == AttnBiasType.NO_BIAS: - assert bias is None + assert ( + bias is None + ), f"bias must be None when attn_bias_type is NO_BIAS, but got bias={bias}" else: - assert bias is not None + assert ( + bias is not None + ), f"bias must not be None when attn_bias_type is {attn_bias_type}" bias = bias.astype(input_dtype) self._assert_dtypes(query, key, value, qkv_layout) @@ -823,11 +831,13 @@ def __call__( key, value = jnp.split(key, [1], axis=-3) key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value]) else: - assert qkv_layout.is_separate() + assert ( + qkv_layout.is_separate() + ), f"Expected separate qkv_layout, but got {qkv_layout}" assert sequence_descriptor is None or isinstance( sequence_descriptor, (jnp.ndarray, np.ndarray) - ) + ), f"sequence_descriptor must be None or ndarray, but got {type(sequence_descriptor)}" x = _UnfusedDotProductAttention( attention_dropout=self.attention_dropout, @@ -994,7 +1004,7 @@ def _canonicalize_lora_scope(scope): SCOPE_EX_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP, - ] + ], f"Unsupported LoRA scope: {scope}" lora_scope = LoRAScope() @@ -1307,8 +1317,10 @@ def query_init(*args): return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0) def qkv_init(key, shape, dtype): - assert len(shape) == 3 - assert shape[-2] == 3 + assert ( + len(shape) == 3 + ), f"qkv_init expects 3D shape, but got {len(shape)}D shape {shape}" + assert shape[-2] == 3, f"qkv_init expects shape[-2] == 3, but got shape={shape}" q_key, k_key, v_key = jax_random.split(key, num=3) @@ -1323,8 +1335,8 @@ def qkv_init(key, shape, dtype): return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype) def kv_init(key, shape, dtype): - assert len(shape) == 3 - assert shape[-2] == 2 + assert len(shape) == 3, f"kv_init expects 3D shape, but got {len(shape)}D shape {shape}" + assert shape[-2] == 2, f"kv_init expects shape[-2] == 2, but got shape={shape}" k_key, v_key = jax_random.split(key) @@ -1415,7 +1427,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): )(inputs_q) if is_self_attn: - assert ln_out is not None + assert ln_out is not None, "ln_out must not be None for self-attention" inputs_kv = ln_out kv_proj = DenseGeneral( @@ -1475,7 +1487,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): )(inputs_q) if is_self_attn: - assert ln_out is not None + assert ln_out is not None, "ln_out must not be None for self-attention" inputs_kv = ln_out query = query.astype(input_dtype) @@ -1494,7 +1506,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): elif qkv_layout == QKVLayout.BSHD_BS2HD: key, value = jnp.split(kv_proj, [1], axis=-2) else: - assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD + assert ( + qkv_layout == QKVLayout.BSHD_BSHD_BSHD + ), f"Expected QKVLayout.BSHD_BSHD_BSHD, but got {qkv_layout}" # No changes to memory layout, should trigger bitcast only (Ideally no Perf impact) query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim)) @@ -1520,7 +1534,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) if decode: - assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD + assert ( + qkv_layout == QKVLayout.BSHD_BSHD_BSHD + ), f"decode mode requires QKVLayout.BSHD_BSHD_BSHD, but got {qkv_layout}" is_initialized = self.has_variable("cache", "cached_key") cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) @@ -1588,7 +1604,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint) dpa_args = [query, kv_proj, None] else: - assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD + assert ( + qkv_layout == QKVLayout.BSHD_BSHD_BSHD + ), f"Expected QKVLayout.BSHD_BSHD_BSHD, but got {qkv_layout}" query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim)) key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim)) value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) @@ -2101,7 +2119,9 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): l = inputs.shape[sequence_dim] attn_bias = rel_emb(l, l, False) - assert inputs.ndim == 3 + assert ( + inputs.ndim == 3 + ), f"inputs must be 3D (batch, sequence, hidden), but got {inputs.ndim}D" # Make name be the exactly same as T5X, since names would affect # RNGKey during init and apply. Myabe no need in the feature. @@ -2151,10 +2171,15 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode) def hidden_dropout(x, deterministic): - assert isinstance(self.hidden_dropout_dims, Sequence) + assert isinstance( + self.hidden_dropout_dims, Sequence + ), f"hidden_dropout_dims must be a Sequence, but got {type(self.hidden_dropout_dims)}" x_shape_len = len(x.shape) for dims in self.hidden_dropout_dims: - assert -x_shape_len <= dims < x_shape_len + assert -x_shape_len <= dims < x_shape_len, ( + f"hidden_dropout_dims value {dims} is out of range " + f"[{-x_shape_len}, {x_shape_len}) for input with {x_shape_len} dimensions" + ) return nn.Dropout( rate=self.hidden_dropout, @@ -2179,7 +2204,9 @@ def hidden_dropout(x, deterministic): )(x, deterministic=deterministic) if self.apply_residual_connection_post_layernorm: - assert ln_out is not None + assert ( + ln_out is not None + ), "ln_out must not be None when apply_residual_connection_post_layernorm is True" residual = ln_out x = x + residual @@ -2239,7 +2266,9 @@ def hidden_dropout(x, deterministic): y = hidden_dropout(y, deterministic) if self.apply_residual_connection_post_layernorm: - assert ln_out is not None + assert ( + ln_out is not None + ), "ln_out must not be None when apply_residual_connection_post_layernorm is True" residual = ln_out mlp_input = y + residual @@ -2284,7 +2313,9 @@ def hidden_dropout(x, deterministic): )(mlp_input, deterministic=deterministic) if self.apply_residual_connection_post_layernorm: - assert ln_out is not None + assert ( + ln_out is not None + ), "ln_out must not be None when apply_residual_connection_post_layernorm is True" residual = ln_out z = with_sharding_constraint_by_logical_axes( diff --git a/transformer_engine/jax/layernorm.py b/transformer_engine/jax/layernorm.py index 3f3f3802db..0f173a89e3 100644 --- a/transformer_engine/jax/layernorm.py +++ b/transformer_engine/jax/layernorm.py @@ -31,7 +31,11 @@ def canonicalize_norm_type(x): Canonicalized normalization type string """ canonicalized = x.lower().strip().replace("-", "").replace("_", "") - assert canonicalized in ["layernorm", "rmsnorm"] + if canonicalized not in ["layernorm", "rmsnorm"]: + raise ValueError( + f"Unsupported normalization type '{x}' (canonicalized: '{canonicalized}'). " + "Valid options are: 'layernorm', 'rmsnorm'." + ) return canonicalized diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index e9f64bb693..2de4576e05 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -271,14 +271,23 @@ def fused_attn_fwd( attn_scale = 1.0 / math.sqrt(d) if attn_bias_type not in ["no_bias", "alibi"]: - assert ( - attn_bias is not None - ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." - assert attn_bias.dtype == q.dtype, "attn_bias tensor must be in the same dtype as q and kv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." + if attn_bias is None: + raise ValueError( + f"attn_bias tensor cannot be None when attn_bias_type={attn_bias_type!r}." + ) + if attn_bias.dtype != q.dtype: + raise ValueError( + "attn_bias tensor must have the same dtype as q and kv: " + f"attn_bias.dtype={attn_bias.dtype} but q.dtype={q.dtype}." + ) + + if fused_attention_backend == FusedAttnBackend["No_Backend"]: + raise ValueError( + "Fused attention does not support this input combination:" + f" qkv_layout={qkv_layout!r}, attn_bias_type={attn_bias_type!r}," + f" attn_mask_type={attn_mask_type!r}, q.shape={list(q.shape)}," + f" q.dtype={q.dtype}, backend={fused_attention_backend}." + ) # BF16/FP16 fused attention API from fmha_v1 apex if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: @@ -294,12 +303,16 @@ def fused_attn_fwd( max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - assert ( - s_quantizer is not None - ), "s_quantizer is required as an input for FP8 fused attention." - assert ( - o_quantizer is not None - ), "o_quantizer is required as an input for FP8 fused attention." + if s_quantizer is None: + raise ValueError( + "s_quantizer is required for FP8 fused attention forward" + f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." + ) + if o_quantizer is None: + raise ValueError( + "o_quantizer is required for FP8 fused attention forward" + f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." + ) else: raise ValueError(f"Unsupported backend {fused_attention_backend}") @@ -488,28 +501,44 @@ def fused_attn_bwd( d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." + if fused_attention_backend == FusedAttnBackend["No_Backend"]: + raise ValueError( + "Fused attention backward does not support this input combination:" + f" qkv_layout={qkv_layout!r}, attn_bias_type={attn_bias_type!r}," + f" attn_mask_type={attn_mask_type!r}, q.shape={list(q.shape)}," + f" q.dtype={q.dtype}, backend={fused_attention_backend}." + ) if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: - assert ( - len(aux_ctx_tensors) >= 1 - ), "aux_ctx_tensors must contain rng_state as its last element." + if len(aux_ctx_tensors) < 1: + raise ValueError( + "aux_ctx_tensors must contain rng_state as its last element," + f" but got len(aux_ctx_tensors)={len(aux_ctx_tensors)}" + f" for backend={fused_attention_backend}." + ) if fused_attention_backend == FusedAttnBackend["FP8"]: - assert ( - s_quantizer is not None - ), "s_quantizer is required as an input for FP8 fused attention backward." - assert ( - dp_quantizer is not None - ), "dp_quantizer is required as an input for FP8 fused attention backward." - assert ( - dqkv_dtype is not None - ), "dqkv_dtype is required as an input for FP8 fused attention backward." - assert ( - len(aux_ctx_tensors) == 3 - ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." + if s_quantizer is None: + raise ValueError( + "s_quantizer is required for FP8 fused attention backward" + f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." + ) + if dp_quantizer is None: + raise ValueError( + "dp_quantizer is required for FP8 fused attention backward" + f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." + ) + if dqkv_dtype is None: + raise ValueError( + "dqkv_dtype is required for FP8 fused attention backward" + f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." + ) + if len(aux_ctx_tensors) != 3: + raise ValueError( + "aux_ctx_tensors must be [M, ZInv, rng_state] for FP8 fused attention," + f" but got len(aux_ctx_tensors)={len(aux_ctx_tensors)}" + f" (backend={fused_attention_backend})." + ) output_tensors = tex.fused_attn_bwd( max_seqlen_q, diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 05219b7b18..9da2f889a4 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -124,7 +124,11 @@ def tensor_group_process_after_reload(tensor_group: TensorGroup): """ Call for a tensor group, just after reload logic. """ - assert tensor_group.aux is not None + if tensor_group.aux is None: + raise RuntimeError( + "TensorGroup.aux must be set before post-reload processing, " + f"but got aux=None for tensor_group with {len(tensor_group.tensor_list)} tensors" + ) tensor_group = TensorGroupProcessor._restore_tensor_duplicates(tensor_group) tensor_group = TensorGroupProcessor._switch_to_views(tensor_group) return tensor_group @@ -158,9 +162,8 @@ def _check_if_offload_base_tensor(tensor: torch.Tensor) -> bool: if _check_if_offload_base_tensor(tensor): aux["views"].append((tensor.shape, tensor.stride(), tensor.storage_offset())) tensor = tensor._base - assert ( - tensor is not None - ), "Cannot offload base tensor, if the tensor is not a view." + if tensor is None: + raise RuntimeError("Cannot offload base tensor, if the tensor is not a view.") tensor_group.tensor_list[tensor_id] = tensor else: aux["views"].append(None) @@ -247,9 +250,10 @@ def __init__( self.state = "not_offloaded" def _validate_state(self, func_name: str, allowed_states: list[str]): - assert ( - self.state in allowed_states - ), f"Invalid state: {self.state} for {func_name}, must be one of {allowed_states}" + if self.state not in allowed_states: + raise RuntimeError( + f"Invalid state: {self.state} for {func_name}, must be one of {allowed_states}" + ) def start_offload(self): """ @@ -271,7 +275,12 @@ def start_offload(self): ) for tensor_id, tensor in enumerate(self.fwd_gpu_tensor_group.tensor_list): - assert tensor.is_contiguous() + if not tensor.is_contiguous(): + raise ValueError( + f"Tensor at index {tensor_id} must be contiguous for CPU offloading, " + f"but got non-contiguous tensor with shape={tensor.shape}, " + f"stride={tensor.stride()}, dtype={tensor.dtype}" + ) # Wait for the moment the tensor is ready to be offloaded. self.offload_stream.wait_event(self.fwd_gpu_tensor_group.events[tensor_id]) # type: ignore[arg-type] @@ -284,12 +293,13 @@ def start_offload(self): self.cpu_tensor_group.tensor_list.append(offloaded_tensor) else: offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id] - assert offloaded_tensor.shape == tensor.shape, ( - "CPU buffer shape does not match the offloaded tensor shape:" - f" {offloaded_tensor.shape} != {tensor.shape} " - "Make sure that tensor shapes do not change between" - " iterations if retain_pinned_cpu_buffers is True." - ) + if offloaded_tensor.shape != tensor.shape: + raise ValueError( + "CPU buffer shape does not match the offloaded tensor shape:" + f" {offloaded_tensor.shape} != {tensor.shape} " + "Make sure that tensor shapes do not change between" + " iterations if retain_pinned_cpu_buffers is True." + ) offloaded_tensor.copy_(tensor, non_blocking=True) # aux is a dictionary that contains auxiliary data like information which tensors were deduplicated, @@ -420,7 +430,11 @@ def pop_tensor( return self.fwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id] # 4. the layer was offloaded - assert self.state == "reload_started" + if self.state != "reload_started": + raise RuntimeError( + "Expected state='reload_started' when popping an offloaded tensor, " + f"but got state='{self.state}' for tensor={tensor_or_tensor_id}" + ) # wait for the tensor to be reloaded torch.cuda.current_stream().wait_event( self.bwd_gpu_tensor_group.events[tensor_or_tensor_id] @@ -824,18 +838,19 @@ def get_cpu_offload_context( raise RuntimeError("CPU offload is not supported in debug mode.") if not manual_synchronization: - assert ( - num_layers <= model_layers - 1 - ), "Cannot offload all layers without manual synchronization - last layer is not offloaded." + if num_layers > model_layers - 1: + raise ValueError( + "Cannot offload all layers without manual synchronization - last layer is not" + f" offloaded. Got num_layers={num_layers}, model_layers={model_layers}." + ) if num_layers == model_layers - 1: warnings.warn( "Offloading num_layers == model_layers - 1 is not recommended, it prevents" " overlapping of computation and offload/reload." ) - assert ( - offload_stream is None or manual_synchronization - ), "offload_stream can be provided only if manual_synchronization is True" + if offload_stream is not None and not manual_synchronization: + raise ValueError("offload_stream can be provided only if manual_synchronization is True") if manual_synchronization: offload_synchronizer = ManualOffloadSynchronizer( @@ -858,9 +873,10 @@ def __init__(self): self.inside_context = False def __enter__(self): - assert ( - self.inside_context is False - ), "Offloading context was entered without synchronization function being called." + if self.inside_context: + raise RuntimeError( + "Offloading context was entered without synchronization function being called." + ) self.inside_context = True self._hooks_ctx = saved_tensors_hooks( offload_synchronizer.push_tensor, offload_synchronizer.pop_tensor @@ -882,12 +898,23 @@ def synchronization_function(self, tensor): """ This function is used to catch the backward pass of the model. """ - assert tensor.requires_grad is True - assert self.current_layer is not None + if not tensor.requires_grad: + raise ValueError( + "Tensor passed to synchronization_function must require grad to " + "register backward hooks, but got requires_grad=False for tensor " + f"with shape={tensor.shape}, dtype={tensor.dtype}" + ) + if self.current_layer is None: + raise RuntimeError( + "synchronization_function called but no layer has been set via __enter__. " + f"inside_context={self.inside_context}, " + f"offload_synchronizer num_layers={self.offload_synchronizer.num_layers}" + ) cur_layer = self.current_layer - assert ( - self.inside_context is False - ), "Synchronization function was called without offloading context being entered." + if self.inside_context: + raise RuntimeError( + "Synchronization function was called without offloading context being entered." + ) def hook(_): # offload_synchronizer.finish_part_of_bwd needs diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 645dbb48d2..b06f6f5619 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -272,7 +272,8 @@ at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType } else if (size == 1) { return at::empty({static_cast(shape.data[0])}, at::CUDA(GetATenDType(type))); } - NVTE_CHECK(false, "Should never reach here! func: allocateSpace"); + NVTE_ERROR("Unsupported tensor allocation: ndim=", size, ", init_to_zeros=", init_to_zeros, + ". Only 1D and 2D tensors are supported."); } at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype) { diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0214f7ff71..7e13cc105f 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -2184,7 +2184,8 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // Compute amax. if (this->with_rht) { if (input.dtype() != DType::kBFloat16) { - NVTE_CHECK(false, "RHT is only supported for bfloat16 input"); + NVTE_ERROR("RHT is only supported for bfloat16 input, got dtype enum value ", + static_cast(input.dtype())); } if (this->with_post_rht_amax) { // We need: @@ -2196,7 +2197,9 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou }); } else { // raise error since it's not supported yet - NVTE_CHECK(false, "Pre-RHT amax is not supported yet"); + NVTE_ERROR( + "Pre-RHT amax is not supported yet. " + "Use with_post_rht_amax=true instead."); } } else { // Without RHT if (compute_amax) { diff --git a/transformer_engine/pytorch/custom_recipes/gemm.py b/transformer_engine/pytorch/custom_recipes/gemm.py index 8f853ff093..3d1e1cc43e 100644 --- a/transformer_engine/pytorch/custom_recipes/gemm.py +++ b/transformer_engine/pytorch/custom_recipes/gemm.py @@ -32,7 +32,8 @@ def custom_gemm( grad: bool = False, ) -> Iterable[Optional[torch.Tensor]]: """Dispatch GEMM to quantizer's qgemm method.""" - assert is_custom(A) and is_custom(B), "A and B must be custom tensors" + if not (is_custom(A) and is_custom(B)): + raise TypeError("A and B must be custom tensors") A, B = B, A @@ -68,11 +69,16 @@ def custom_gemm( if gemm_type == GEMMType.FPROP: qx, sx = A.data, A.scale qw, sw = B.data, B.scale - assert qx is not None - assert sx is not None - assert qw is not None - assert sw is not None - assert A.original_shape is not None + if qx is None: + raise ValueError("FPROP GEMM: quantized activation data (A.data) is None") + if sx is None: + raise ValueError("FPROP GEMM: activation scale (A.scale) is None") + if qw is None: + raise ValueError("FPROP GEMM: quantized weight data (B.data) is None") + if sw is None: + raise ValueError("FPROP GEMM: weight scale (B.scale) is None") + if A.original_shape is None: + raise ValueError("FPROP GEMM: A.original_shape is None, cannot determine output shape") # Call quantizer's qgemm method result = quantizer.qgemm( @@ -95,10 +101,14 @@ def custom_gemm( elif gemm_type == GEMMType.DGRAD: qdy, sdy = A.data, A.scale qw_t, sw_t = B.data_t, B.scale_t - assert qdy is not None - assert sdy is not None - assert qw_t is not None - assert sw_t is not None + if qdy is None: + raise ValueError("DGRAD GEMM: quantized gradient data (A.data) is None") + if sdy is None: + raise ValueError("DGRAD GEMM: gradient scale (A.scale) is None") + if qw_t is None: + raise ValueError("DGRAD GEMM: transposed quantized weight data (B.data_t) is None") + if sw_t is None: + raise ValueError("DGRAD GEMM: transposed weight scale (B.scale_t) is None") result = quantizer.qgemm( qdy, @@ -115,10 +125,14 @@ def custom_gemm( elif gemm_type == GEMMType.WGRAD: qdy_t, sdy_t = A.data_t, A.scale_t qx_t, sx_t = B.data_t, B.scale_t - assert qdy_t is not None - assert sdy_t is not None - assert qx_t is not None - assert sx_t is not None + if qdy_t is None: + raise ValueError("WGRAD GEMM: transposed quantized gradient data (A.data_t) is None") + if sdy_t is None: + raise ValueError("WGRAD GEMM: transposed gradient scale (A.scale_t) is None") + if qx_t is None: + raise ValueError("WGRAD GEMM: transposed quantized activation data (B.data_t) is None") + if sx_t is None: + raise ValueError("WGRAD GEMM: transposed activation scale (B.scale_t) is None") result = quantizer.qgemm( qdy_t, diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index d00d0c8b94..f42183ec09 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -169,7 +169,8 @@ def high_precision_gemm_ref( y_shape = (mat1.size(0), mat2.size(1)) if bias is not None: - assert not accumulate, "Bias is not supported with accumulation" + if accumulate: + raise ValueError("Bias is not supported with accumulation") bias = bias.to(out_dtype) # With bias case if out_dtype == torch.float32: @@ -325,7 +326,8 @@ def size(self, *args, **kwargs): # pylint: disable=unused-argument the second dimension by half. This method returns the logical shape that users expect, not the internal packed storage shape. """ - assert self.original_shape is not None + if self.original_shape is None: + raise RuntimeError("NVFP4TensorRef.size() called but original_shape has not been set") return torch.Size(self.original_shape) @@ -374,7 +376,8 @@ def _build_hadamard_matrix( Uses Sylvester construction to avoid SciPy dependency. """ - assert (size & (size - 1)) == 0, "Hadamard size must be a power of two" + if (size & (size - 1)) != 0: + raise ValueError(f"Hadamard size must be a power of two, got {size}") h = torch.ones((1, 1), device=device, dtype=torch.float32) while h.shape[0] < size: h = torch.cat( @@ -402,9 +405,10 @@ def _apply_rht(self, x: torch.Tensor) -> torch.Tensor: # RHT dimension equals the quantization tile length (NVFP4 uses 16) rht_dim = self.quant_tile_shape[1] - assert ( - x.shape[-1] % rht_dim == 0 - ), f"Inner dimension {x.shape[-1]} must be divisible by hadamard dimension {rht_dim}" + if x.shape[-1] % rht_dim != 0: + raise ValueError( + f"Inner dimension {x.shape[-1]} must be divisible by hadamard dimension {rht_dim}" + ) # Build H and scale H = self._build_hadamard_matrix(rht_dim, x.device, x.dtype, self.with_random_sign_mask) @@ -446,7 +450,11 @@ def _quantize_blockwise_reference( eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.ndim == 2 + if x.ndim != 2: + raise ValueError( + f"_quantize_blockwise_reference expects a 2D tensor, got {x.ndim}D with shape" + f" {x.shape}" + ) using_2d_quantization = tile_len_x == 16 and tile_len_y == 16 m, n = x.shape # Compute vec_max based on the original x (before reshape) @@ -525,7 +533,11 @@ def _pad_tensor( tensor: torch.Tensor, row_divisor: Optional[int], col_divisor: Optional[int] ) -> torch.Tensor: - assert tensor.dim() == 2, "only supports 2D tensors" + if tensor.dim() != 2: + raise ValueError( + f"_pad_tensor only supports 2D tensors, got {tensor.dim()}D tensor with shape" + f" {tensor.shape}" + ) M, N = tensor.shape padding_needed_rows = 0 padding_needed_cols = 0 @@ -553,7 +565,11 @@ def _pad_tensor( @staticmethod def _rm_pad_tensor(tensor: torch.Tensor, original_size: tuple[int, ...]) -> torch.Tensor: - assert tensor.dim() == 2, "only supports 2D tensors" + if tensor.dim() != 2: + raise ValueError( + f"_rm_pad_tensor only supports 2D tensors, got {tensor.dim()}D tensor with shape" + f" {tensor.shape}" + ) M, N = original_size out = tensor[:M, :N].contiguous() return out @@ -584,19 +600,20 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ - sx_t: scale tensor for qx_t (if columnwise_usage), None otherwise - global_amax_row, global_amax_col: global amax tensors """ + global_amax_col = None if self.pow_2_scales: - assert self.quant_tile_shape == ( - 1, - 32, - ), "MXFP4 only supports 1x32 tile shape." + if self.quant_tile_shape != (1, 32): + raise ValueError( + f"MXFP4 only supports 1x32 tile shape, got {self.quant_tile_shape}" + ) # TODO(etsykunov): Fix bug where global_amax_row and # global_amax_col are not defined # global_amax = torch.empty(0, device=tensor.device, dtype=torch.float32) else: - assert self.quant_tile_shape in ( - (1, 16), - (16, 16), - ), "NVFP4 only supports 1x16 or 16x16 tile shape." + if self.quant_tile_shape not in ((1, 16), (16, 16)): + raise ValueError( + f"NVFP4 only supports 1x16 or 16x16 tile shape, got {self.quant_tile_shape}" + ) # Prepare inputs once so we can reuse for both amax and quantization # Row-input will always be the original input. row_input = tensor @@ -670,7 +687,11 @@ def quantize( **kwargs, # pylint: disable=unused-argument ) -> NVFP4TensorRef: # sanity checks - assert tensor.dtype in utils.HIGH_PRECISION_FLOAT_DTYPES, "Unsupported input dtype." + if tensor.dtype not in utils.HIGH_PRECISION_FLOAT_DTYPES: + raise TypeError( + f"Unsupported input dtype {tensor.dtype}, expected one of" + f" {utils.HIGH_PRECISION_FLOAT_DTYPES}" + ) # Make it work with 3D tensors original_shape = tensor.shape @@ -766,7 +787,10 @@ def is_data_t_transposed_in_memory(self) -> bool: TODO(etsykunov): Confirm docstring is correct. """ - raise NotImplementedError("Not implemented yet") + raise NotImplementedError( + "NVFP4QuantizerRef.is_data_t_transposed_in_memory is not implemented for FP4" + " quantization" + ) def qgemm( self, @@ -784,7 +808,8 @@ def qgemm( qresult_w: QuantizedTensorStorage | None = None, ) -> torch.Tensor: """Python implementation of microblock FP4 GEMM.""" - assert bias is None, "Bias is implemented for FP4 GEMM." + if bias is not None: + raise ValueError("Bias is not supported in NVFP4QuantizerRef.qgemm") high_precision_x = cast_from_fp4x2(qx, out_dtype) high_precision_w = cast_from_fp4x2(qw, out_dtype) @@ -814,11 +839,22 @@ def qgemm( else: - assert qresult_x is not None - assert qresult_w is not None - - assert qresult_x.global_amax_row is not None - assert qresult_w.global_amax_col is not None + if qresult_x is None: + raise ValueError( + "qresult_x is required for non-pow_2_scales NVFP4 GEMM (needed for global_amax)" + ) + if qresult_w is None: + raise ValueError( + "qresult_w is required for non-pow_2_scales NVFP4 GEMM (needed for global_amax)" + ) + if qresult_x.global_amax_row is None: + raise ValueError( + "qresult_x.global_amax_row must be set for non-pow_2_scales NVFP4 GEMM" + ) + if qresult_w.global_amax_col is None: + raise ValueError( + "qresult_w.global_amax_col must be set for non-pow_2_scales NVFP4 GEMM" + ) sx = sx.to(torch.float32) sw = sw.to(torch.float32) @@ -833,23 +869,27 @@ def qgemm( M, K = high_precision_x.shape N, K_w = high_precision_w.shape - assert K == K_w, "K dimension mismatch between qx and qw" - - assert K % 32 == 0, "K dimension must be divisible by 32" - assert N % 8 == 0, "N dimension must be divisible by 8" + if K != K_w: + raise ValueError( + f"K dimension mismatch between qx and qw: qx has K={K}, qw has K={K_w}" + ) + if K % 32 != 0: + raise ValueError(f"K dimension must be divisible by 32, got K={K}") + if N % 8 != 0: + raise ValueError(f"N dimension must be divisible by 8, got N={N}") block_length = 32 if self.pow_2_scales else 16 grid_k = K // block_length - assert sx.shape == ( - M, - K // block_length, - ), f"sx shape mismatch: expected ({M}, {K//block_length}), got {sx.shape}" - assert sw.shape == ( - N, - K // block_length, - ), f"sw shape mismatch: expected ({N}, {K//block_length}), got {sw.shape}" + if sx.shape != (M, K // block_length): + raise ValueError( + f"sx shape mismatch: expected ({M}, {K // block_length}), got {sx.shape}" + ) + if sw.shape != (N, K // block_length): + raise ValueError( + f"sw shape mismatch: expected ({N}, {K // block_length}), got {sw.shape}" + ) y = torch.zeros(M, N, dtype=torch.float32, device=qx.device) @@ -878,10 +918,12 @@ def qgemm( # accumulation happens at epilogue in float32 if accumulate: - assert out is not None, "Output tensor must be provided for accumulation." + if out is None: + raise ValueError("Output tensor must be provided for accumulation.") y += out.to(torch.float32) else: - assert out is None, "Output tensor should be None when accumulate is False." + if out is not None: + raise ValueError("Output tensor should be None when accumulate is False.") y = y.to(out_dtype) return y diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 2a65fa272b..03ac9c1595 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -152,7 +152,11 @@ def set_tensor_model_parallel_attributes( ) -> None: """set attributes needed for TP""" for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: - assert not hasattr(tensor, attribute) + if hasattr(tensor, attribute): + raise RuntimeError( + f"Tensor already has attribute '{attribute}' set. Cannot set " + "tensor model parallel attributes on a tensor that already has them." + ) # Set the attributes. setattr(tensor, "tensor_model_parallel", is_parallel) setattr(tensor, "partition_dim", dim) @@ -170,7 +174,11 @@ def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int: @lru_cache def get_distributed_rank(group: Optional[dist_group_type] = None) -> int: """Return my rank for the distributed group.""" - assert torch.distributed.is_initialized(), "torch.distributed is not initialized." + if not torch.distributed.is_initialized(): + raise RuntimeError( + "torch.distributed is not initialized. Call torch.distributed.init_process_group() " + "before calling get_distributed_rank()." + ) return torch.distributed.get_rank(group=group) @@ -743,7 +751,12 @@ def checkpoint( # If saved activations need to be distributed but there is no process group, # default to the world group. if distribute_saved_activations: - assert torch.distributed.is_initialized(), "torch.distributed is not initialized." + if not torch.distributed.is_initialized(): + raise RuntimeError( + "torch.distributed is not initialized. Call " + "torch.distributed.init_process_group() before using " + "distribute_saved_activations=True." + ) tp_group = torch.distributed.GroupMember.WORLD if tp_group is None else tp_group return _CheckpointFunction.apply( @@ -917,9 +930,12 @@ def reduce_scatter_along_first_dim( return inp, None dim_size = list(inp.size()) - assert ( - dim_size[0] % world_size == 0 - ), "First dimension of the tensor should be divisible by tensor parallel size" + if dim_size[0] % world_size != 0: + raise ValueError( + "First dimension of the tensor should be divisible by tensor parallel size, " + f"but got dim_size[0]={dim_size[0]} and world_size={world_size} " + f"(remainder={dim_size[0] % world_size})." + ) dim_size[0] = dim_size[0] // world_size @@ -984,7 +1000,11 @@ def _all_gather_fp8( # Note: We cannot directly all-gather the transposed FP8 tensor, # so temporarily modify quantizer to avoid creating FP8 transpose. if not isinstance(inp, Float8TensorStorage): - assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)) + if not isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): + raise TypeError( + "Expected quantizer to be Float8Quantizer or Float8CurrentScalingQuantizer " + f"when input is not Float8TensorStorage, but got {type(quantizer).__name__}." + ) # we cannot directly gather the transposed fp8 tensor # so we need to disable columnwise usage for the quantizer # and then set it back to the original value after quantizing @@ -1234,10 +1254,18 @@ def _swap_first_dims(tensor: torch.Tensor, world_size: int): """ shape = tensor.shape - assert len(shape) >= 2, "Wrong number of dimensions for fixing interleave." + if len(shape) < 2: + raise ValueError( + f"Wrong number of dimensions for fixing interleave: got {len(shape)}, " + f"expected at least 2 (shape={shape})." + ) first_dim = shape[0] flattened_trailing = math.prod(shape[1:]) - assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave." + if first_dim % world_size != 0: + raise ValueError( + f"Wrong dimensions for fixing interleave: first_dim={first_dim} is not divisible " + f"by world_size={world_size} (remainder={first_dim % world_size})." + ) tensor = tensor.reshape(world_size, first_dim // world_size, flattened_trailing) tensor = tex.swap_first_dims(tensor, out=None) return tensor.reshape(first_dim // world_size, flattened_trailing * world_size) @@ -1327,7 +1355,11 @@ def _all_gather_nvfp4( f"found {inp.__class__.__name__})" ) - assert in_shape is not None or in_shape_t is not None, "No data found." + if in_shape is None and in_shape_t is None: + raise ValueError( + "No data found: both in_shape and in_shape_t are None. " + "Input tensor must have rowwise or columnwise data." + ) world_size = get_distributed_world_size(process_group) @@ -1380,7 +1412,11 @@ def _all_gather_nvfp4( if quantizer.rowwise_usage: # Remove padding from NVFP4 scale-inverses - assert in_shape is not None, "Shape not found." + if in_shape is None: + raise RuntimeError( + "Shape not found: in_shape is None but rowwise_usage is True. " + "Input tensor must have rowwise data for NVFP4 rowwise gathering." + ) in_scale_inv = inp._rowwise_scale_inv out_scale_inv = out._rowwise_scale_inv flattened_in_shape0 = math.prod(in_shape[:-1]) @@ -1681,7 +1717,10 @@ def gather_along_first_dim( # MXFP8 case if isinstance(inp, MXFP8TensorStorage) or isinstance(quantizer, MXFP8Quantizer): - assert isinstance(quantizer, MXFP8Quantizer) + if not isinstance(quantizer, MXFP8Quantizer): + raise TypeError( + f"Expected MXFP8Quantizer for MXFP8 all-gather, but got {type(quantizer).__name__}." + ) return _all_gather_mxfp8( inp, process_group, @@ -1692,7 +1731,10 @@ def gather_along_first_dim( # NVFP4 case if isinstance(inp, NVFP4TensorStorage) or isinstance(quantizer, NVFP4Quantizer): - assert isinstance(quantizer, NVFP4Quantizer) + if not isinstance(quantizer, NVFP4Quantizer): + raise TypeError( + f"Expected NVFP4Quantizer for NVFP4 all-gather, but got {type(quantizer).__name__}." + ) return _all_gather_nvfp4( inp, process_group, @@ -1835,8 +1877,15 @@ def symmetric_all_reduce( - The second element is the async work handle if async_op=True, otherwise None. """ - assert async_op is False, "Async symmetric ops no supported yet" - assert HAS_TORCH_SYMMETRIC, "Could not import symetric memory from torch" + if async_op: + raise RuntimeError( + f"Async symmetric ops are not supported yet, but async_op={async_op!r} was passed." + ) + if not HAS_TORCH_SYMMETRIC: + raise RuntimeError( + "Could not import symmetric memory from torch. " + "Please ensure torch.distributed._symmetric_memory is available." + ) if get_distributed_world_size(tp_group) == 1: return inp, None @@ -1969,10 +2018,19 @@ def _fsdp_gather_tensors( *tensors: torch.Tensor, ): if fsdp_group is not None: - assert len(shapes) == len(tensors), "Number of tensors and tensor shapes must be equal." + if len(shapes) != len(tensors): + raise ValueError( + "Number of tensors and tensor shapes must be equal, " + f"but got {len(shapes)} shapes and {len(tensors)} tensors." + ) for s, t in zip(shapes, tensors): if isinstance(t, torch.Tensor): - assert s is not None, "Internal TE error." + if s is None: + raise RuntimeError( + "Internal TE error: shape is None for a non-None tensor in " + "post_optimizer_step_fwd_amax_reduction. " + f"Tensor type: {type(t).__name__}, tensor shape: {t.shape}." + ) targets = t.get_data_tensors() if isinstance(t, QuantizedTensor) else [t] for target in targets: safely_set_viewless_tensor_data( @@ -2020,17 +2078,23 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: fsdp_root : torch.nn.Module FSDP-wrapped root module that may contain FSDP-wrapped TE modules. """ - assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped." + if not isinstance(fsdp_root, FSDP): + raise TypeError(f"Root module must be FSDP-wrapped, but got {type(fsdp_root).__name__}.") # If the root module is a TE module, inject FSDP information into it if _is_te_module(fsdp_root.module): if hasattr(fsdp_root, "primary_weights_in_fp8"): - assert not fsdp_root.primary_weights_in_fp8, ( - "TE modules with primary weights in FP8 cannot be FSDP-wrapped. " - "Please initialize your model without the te.quantized_model_init(...) context." - ) + if fsdp_root.primary_weights_in_fp8: + raise RuntimeError( + "TE modules with primary weights in FP8 cannot be FSDP-wrapped. " + "Please initialize your model without the te.quantized_model_init(...) context." + ) root_state = _get_module_fsdp_state(fsdp_root) - assert root_state is not None, "Root module does not have a valid _FSDPState." + if root_state is None: + raise RuntimeError( + f"Root module ({type(fsdp_root.module).__name__}) does not have a valid " + "_FSDPState. Ensure the module is properly wrapped with FSDP." + ) fsdp_root.module.fast_setattr("fsdp_group", root_state.process_group) # Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules @@ -2038,10 +2102,12 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: for state, fsdp_module in zip(fsdp_states, fsdp_modules): if _is_te_module(fsdp_module.module): if hasattr(fsdp_module.module, "primary_weights_in_fp8"): - assert not fsdp_module.module.primary_weights_in_fp8, ( - "TE modules with primary weights in FP8 cannot be FSDP-wrapped. " - "Please initialize your model without the te.quantized_model_init(...) context." - ) + if fsdp_module.module.primary_weights_in_fp8: + raise RuntimeError( + f"TE module '{type(fsdp_module.module).__name__}' with primary weights " + "in FP8 cannot be FSDP-wrapped. Please initialize your model without " + "the te.quantized_model_init(...) context." + ) fsdp_module.module.fast_setattr("fsdp_group", state.process_group) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index bae911b4e1..86b8a4acf4 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -139,7 +139,7 @@ def _make_graphed_callables( # Check training/inference is_training = all(c.training for c in callables) if not is_training and any(c.training for c in callables): - assert False, ( + raise RuntimeError( "make_graphed_callables only supports when modules are all in training or all in" " inference mode." ) @@ -148,8 +148,16 @@ def _make_graphed_callables( _order_without_wgrad = None delay_wgrad_compute = False if _order is None: - assert len(sample_args) == len(callables) - assert len(sample_kwargs) == len(callables) + if len(sample_args) != len(callables): + raise ValueError( + "Expected sample_args to have the same length as callables, " + f"but got {len(sample_args)} sample_args for {len(callables)} callables" + ) + if len(sample_kwargs) != len(callables): + raise ValueError( + "Expected sample_kwargs to have the same length as callables, " + f"but got {len(sample_kwargs)} sample_kwargs for {len(callables)} callables" + ) else: # Custom logic for interleaved pipeline parallelism # Note: This is tightly coupled with the Megatron-core @@ -173,48 +181,62 @@ def _make_graphed_callables( _order_without_wgrad.append(c_id) num_model_chunks = max(_order_without_wgrad) num_microbatches = len(_order_without_wgrad) // num_model_chunks // 2 - assert num_model_chunks * num_microbatches * 2 == len(_order_without_wgrad) + if num_model_chunks * num_microbatches * 2 != len(_order_without_wgrad): + raise ValueError( + f"Pipeline-parallel order dimension mismatch: num_model_chunks ({num_model_chunks})" + f" * num_microbatches ({num_microbatches}) * 2 =" + f" {num_model_chunks * num_microbatches * 2}, but len(_order_without_wgrad) =" + f" {len(_order_without_wgrad)}" + ) # When delay_wgrad_compute is enabled, each layer is treated as a model chunk, which # allows for fine-grained graph capture order. if delay_wgrad_compute: - assert ( - _num_layers_per_chunk is not None - ), "'_num_layers_per_chunk' must be provided when delay_wgrad_compute is True." + if _num_layers_per_chunk is None: + raise ValueError( + "'_num_layers_per_chunk' must be provided when delay_wgrad_compute is True." + ) for num_layers in _num_layers_per_chunk: - assert ( - num_layers == 1 - ), "Each model chunk must have only one layer when delay_wgrad_compute is True." + if num_layers != 1: + raise ValueError( + "Each model chunk must have only one layer when delay_wgrad_compute is" + f" True, but got {num_layers} layers." + ) # Determine number of layers in each model chunk. if _num_layers_per_chunk is None: - assert len(sample_args) * 2 >= len(_order_without_wgrad) and ( - len(sample_args) * 2 % len(_order_without_wgrad) == 0 - ), ( - f"{len(sample_args)} * 2 >= {len(_order_without_wgrad)} and {len(sample_args)} * 2" - f" % {len(_order_without_wgrad)} == 0" - ) + if not ( + len(sample_args) * 2 >= len(_order_without_wgrad) + and (len(sample_args) * 2 % len(_order_without_wgrad) == 0) + ): + raise ValueError( + f"{len(sample_args)} * 2 >= {len(_order_without_wgrad)} and" + f" {len(sample_args)} * 2 % {len(_order_without_wgrad)} == 0" + ) num_layers = len(sample_args) // num_model_chunks // num_microbatches _num_layers_per_chunk = [num_layers] * num_model_chunks else: - assert ( + if not ( isinstance(_num_layers_per_chunk, int) or len(_num_layers_per_chunk) == num_model_chunks - ), ( - "If _num_layers_per_chunk is provided, it must be an integer or a list of" - f" {num_model_chunks} integers, but got {_num_layers_per_chunk}." - ) + ): + raise ValueError( + "If _num_layers_per_chunk is provided, it must be an integer or a list of" + f" {num_model_chunks} integers, but got {_num_layers_per_chunk}." + ) if isinstance(_num_layers_per_chunk, int): _num_layers_per_chunk = [_num_layers_per_chunk] * num_model_chunks total_num_layers = sum(_num_layers_per_chunk) - assert len(callables) == total_num_layers, ( - f"Callables should have ({total_num_layers}) " - + f"entries when order input is provided but got {len(callables)}." - ) - assert len(sample_args) == total_num_layers * num_microbatches, ( - f"Expected {total_num_layers * num_microbatches} " - + f"args tuple, but got {len(sample_args)}." - ) + if len(callables) != total_num_layers: + raise ValueError( + f"Callables should have ({total_num_layers}) " + + f"entries when order input is provided but got {len(callables)}." + ) + if len(sample_args) != total_num_layers * num_microbatches: + raise ValueError( + f"Expected {total_num_layers * num_microbatches} " + + f"args tuple, but got {len(sample_args)}." + ) # Calculate the starting index of each chunk in callables for future use. _prefix_num_layers = [0] @@ -222,19 +244,26 @@ def _make_graphed_callables( num_layers = _num_layers_per_chunk[m_chunk] _prefix_num_layers.append(_prefix_num_layers[-1] + num_layers) - assert len(sample_kwargs) == len(sample_args) + if len(sample_kwargs) != len(sample_args): + raise ValueError( + "Pipeline-parallel schedule requires sample_kwargs and sample_args to have " + f"the same length, but got {len(sample_kwargs)} sample_kwargs " + f"for {len(sample_args)} sample_args" + ) # Check reuse graph conditions and reorganize sample_args and sample_kwargs. # Note: When capturing a graph, we hold onto the args and kwargs so we have static buffers # when the graph is replayed. If two model chunk microbatches have no overlap between their # forward and backward, then we can reduce memory usage by reusing the same static buffers. if _reuse_graph_input_output_buffers: - assert ( - _order is not None - ), "`_order` must be provided when `_reuse_graph_input_output_buffers` is True." - assert ( - is_training - ), "`_reuse_graph_input_output_buffers` is only available in training mode." + if _order is None: + raise ValueError( + "`_order` must be provided when `_reuse_graph_input_output_buffers` is True." + ) + if not is_training: + raise RuntimeError( + "`_reuse_graph_input_output_buffers` is only available in training mode." + ) if isinstance(sample_args, tuple): sample_args = list(sample_args) if isinstance(sample_kwargs, tuple): @@ -300,20 +329,22 @@ def _make_graphed_callables( # Check callables for c in callables: if isinstance(c, torch.nn.Module): - assert ( + if not ( len(c._backward_hooks) == 0 and len(c._forward_hooks) == 0 and len(c._forward_pre_hooks) == 0 - ), ( - "Modules must not have hooks registered at the time they are passed. " - + "However, registering hooks on modules after passing them " - + "through make_graphed_callables is allowed." - ) - assert all(b.requires_grad is False for b in c.buffers()), ( - "In any :class:`~torch.nn.Module` passed to " - + ":func:`~make_graphed_callables`, only parameters may be trainable. " - + "All buffers must have ``requires_grad=False``." - ) + ): + raise RuntimeError( + "Modules must not have hooks registered at the time they are passed. " + + "However, registering hooks on modules after passing them " + + "through make_graphed_callables is allowed." + ) + if not all(b.requires_grad is False for b in c.buffers()): + raise RuntimeError( + "In any :class:`~torch.nn.Module` passed to " + + ":func:`~make_graphed_callables`, only parameters may be trainable. " + + "All buffers must have ``requires_grad=False``." + ) # Flatten callable arguments per_callable_kwargs_keys = [list(kwargs.keys()) for kwargs in sample_kwargs] @@ -322,10 +353,11 @@ def _make_graphed_callables( flatten_arg, _ = _tree_flatten(args) flatten_kwarg, _ = _tree_flatten([kwargs[key] for key in kwargs_keys]) flatten_sample_args.append(tuple(flatten_arg + flatten_kwarg)) - assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), ( - "In the beta API, sample_args " - + "for each callable must contain only Tensors. Other types are not allowed." - ) + if not all(isinstance(arg, torch.Tensor) for arg in flatten_arg): + raise TypeError( + "In the beta API, sample_args " + + "for each callable must contain only Tensors. Other types are not allowed." + ) # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly # passes to forward (ie, its sample_args) AND the module's parameter attributes. @@ -354,7 +386,12 @@ def _make_graphed_callables( ) else () ) - assert len(per_callable_module_params) == len(flatten_sample_args) + if len(per_callable_module_params) != len(flatten_sample_args): + raise ValueError( + "Pipeline-parallel dimension mismatch: " + f"per_callable_module_params has {len(per_callable_module_params)} entries, " + f"but flatten_sample_args has {len(flatten_sample_args)} entries" + ) per_callable_static_input_surfaces = [ flatten_sample_args[i] + per_callable_module_params[i] for i in range(len(flatten_sample_args)) @@ -400,12 +437,12 @@ def _make_graphed_callables( warmup_func_idx.append(func_idx) warmup_func.append(func) fwd_idx[m_chunk] += 1 - assert len(warmup_func) == len( - sample_args - ), f"Warmup runs {len(warmup_func)} don't match args {len(sample_args)}." - assert len(warmup_func_idx) == len( - set(warmup_func_idx) - ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique." + if len(warmup_func) != len(sample_args): + raise ValueError(f"Warmup runs {len(warmup_func)} don't match args {len(sample_args)}.") + if len(warmup_func_idx) != len(set(warmup_func_idx)): + raise RuntimeError( + f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique." + ) # Filter the TE modules that cudagraph can access. visited_te_modules = {} @@ -429,9 +466,10 @@ def hook_fn( modules.add(module) # If forward is called on a te.ops.Sequential it is not called on its constituent ops elif isinstance(module, Sequential): - assert ( - module._module_groups is not None - ), "Should have been initialized by warmup" + if module._module_groups is None: + raise RuntimeError( + "module._module_groups should have been initialized by warmup" + ) for module_group in module._module_groups: if isinstance(module_group, OperationFuser): for basic_op in module_group._basic_ops: @@ -480,20 +518,22 @@ def hook_fn( grad_inputs[grad_inputs_idx] is None and grad_inputs_idx < num_required_grad_sample_args ): - assert allow_unused_input, ( - "The input tensor requires grad, but the grad is None after" - " backward pass." - ) + if not allow_unused_input: + raise RuntimeError( + "The input tensor requires grad, but the grad is None after" + " backward pass." + ) elif ( grad_inputs[grad_inputs_idx] is not None and grad_inputs_idx >= num_required_grad_sample_args ): module_params_with_grad.append(static_input_surface[inputs_idx]) if len(module_params_with_grad) != len(per_callable_module_params[func_idx]): - assert warmup_iter == 0, ( - "no-grad params should only be used as inputs in the first warmup" - " iteration" - ) + if warmup_iter != 0: + raise RuntimeError( + "no-grad params should only be used as inputs in the first warmup" + f" iteration, but found in iteration {warmup_iter}" + ) per_callable_module_params[func_idx] = tuple(module_params_with_grad) static_input_surface = flatten_sample_args[func_idx] + tuple( module_params_with_grad @@ -531,7 +571,10 @@ def hook_fn( previous_chunk_last_callable_bwd_idx = None for i, c_id in enumerate(_order): if c_id > 0: - assert isinstance(c_id, int), "Forward order value must be an integer." + if not isinstance(c_id, int): + raise TypeError( + f"Forward order value must be an integer, but got {type(c_id).__name__}." + ) # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] m_chunk = c_id - 1 for l_no in range(_num_layers_per_chunk[m_chunk]): @@ -583,23 +626,27 @@ def hook_fn( break if wgrad_validation_list[i] is None: wgrad_validation_list[i] = False - assert wgrad_validation_list[i], ( - f"Number of wgrad graph({num_wgrad_c_id}) doesn't match number " - f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}." - ) + if not wgrad_validation_list[i]: + raise RuntimeError( + f"Number of wgrad graph({num_wgrad_c_id}) doesn't match number " + f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}." + ) elif ceil(c_id) != c_id: per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk] - assert is_training, "Only training mode supports backward_dw." + if not is_training: + raise RuntimeError("Only training mode supports backward_dw.") # If no one module needs the backward_dw, the bwd_dw_graph will be empty. # So skip capturing it. For backward_dw, the order value is c_id - 0.5 to indicate # the specific order of backward_dw. - assert ceil(c_id) - c_id == 0.5, ( - "The order diff of wgrad and dgrad must be 0.5, " - f"get {ceil(c_id) - c_id}." - ) - assert need_bwd_dw_graph[ - per_callable_bwd_idx - ], "No module needs wgrad computation but get float in order" + if ceil(c_id) - c_id != 0.5: + raise ValueError( + "The order diff of wgrad and dgrad must be 0.5, " + f"get {ceil(c_id) - c_id}." + ) + if not need_bwd_dw_graph[per_callable_bwd_idx]: + raise RuntimeError( + "No module needs wgrad computation but get float in order" + ) bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx] with _graph_context_wrapper(bwd_dw_graph, pool=mempool): for module in visited_te_modules[per_callable_bwd_idx]: @@ -811,7 +858,11 @@ def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *i torch.cuda.current_stream().wait_stream(cuda_graph_stream) else: fwd_graph.replay() - assert isinstance(static_outputs, tuple) + if not isinstance(static_outputs, tuple): + raise TypeError( + "Expected static_outputs to be a tuple, but got" + f" {type(static_outputs).__name__}" + ) return tuple(o.detach() if o is not None else o for o in static_outputs) @staticmethod @@ -820,7 +871,12 @@ def backward(ctx, *grads): # pylint: disable=missing-function-docstring # Replay backward graph - assert len(grads) == len(static_grad_outputs) + if len(grads) != len(static_grad_outputs): + raise ValueError( + "Backward graph grad dimension mismatch: " + f"received {len(grads)} grads, " + f"but expected {len(static_grad_outputs)} static_grad_outputs" + ) for g, grad in zip(static_grad_outputs, grads): if g is not None: # don't copy if autograd gods have been kind and the @@ -843,7 +899,11 @@ def backward(ctx, *grads): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) # Input args that didn't require grad expect a None gradient. - assert isinstance(static_grad_inputs, tuple) + if not isinstance(static_grad_inputs, tuple): + raise TypeError( + "Expected static_grad_inputs to be a tuple, but got" + f" {type(static_grad_inputs).__name__}" + ) return (None, None, None) + tuple( b.detach() if b is not None else b for b in static_grad_inputs ) @@ -853,9 +913,13 @@ def functionalized(*user_args, **user_kwargs): # Decide whether to update FP8 weights skip_fp8_weight_update = None if cache_quantized_params: - assert "is_first_microbatch" in user_kwargs and isinstance( + if "is_first_microbatch" not in user_kwargs or not isinstance( user_kwargs["is_first_microbatch"], bool - ), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching." + ): + raise ValueError( + "`is_first_microbatch` boolean kwarg must be provided for FP8 weight" + " caching." + ) skip_fp8_weight_update = not user_kwargs["is_first_microbatch"] @@ -1237,12 +1301,16 @@ def make_graphed_callables( modules = (modules,) if not isinstance(enabled, tuple): - assert isinstance(enabled, bool), "enabled must be a bool or a tuple of bools" + if not isinstance(enabled, bool): + raise TypeError( + f"enabled must be a bool or a tuple of bools, but got {type(enabled).__name__}" + ) enabled = (enabled,) * len(modules) else: - assert len(enabled) == len( - modules - ), f"enabled length ({len(enabled)}) must match modules length ({len(modules)})" + if len(enabled) != len(modules): + raise ValueError( + f"enabled length ({len(enabled)}) must match modules length ({len(modules)})" + ) if any(enabled) and recipe is None: recipe = get_default_fp8_recipe() elif not any(enabled): @@ -1278,7 +1346,8 @@ def call_func(self, *args, **kwargs): forward_funcs = [] for module in modules: - assert isinstance(module, torch.nn.Module), f"Graphing for {type(module)} is not supported." + if not isinstance(module, torch.nn.Module): + raise TypeError(f"Graphing for {type(module)} is not supported.") wrap_autocast(module) forward_funcs.append(module) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 9c21141a39..2d4583e936 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -80,7 +80,8 @@ class UserBufferQuantizationMode(Enum): def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor: """Returns a dummy tensor of given shape.""" - assert len(shape) == 2 + if len(shape) != 2: + raise ValueError(f"Expected 2D shape, got {len(shape)}D: {shape}") global _dummy_wgrads if (shape[0], shape[1], dtype) not in _dummy_wgrads: _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty( @@ -156,10 +157,11 @@ def initialize_ub( which also requires ``MPI_HOME=/path/to/mpi/root`` to be set at compile time. """ if not tex.device_supports_multicast(): - assert bool(int(os.getenv("UB_SKIPMC", "0"))), ( - "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with " - + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead." - ) + if not bool(int(os.getenv("UB_SKIPMC", "0"))): + raise RuntimeError( + "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap " + "with CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead." + ) if not quantization_modes: warnings.warn( @@ -171,34 +173,48 @@ def initialize_ub( UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE ] else: - assert isinstance(quantization_modes, list), "quantization_modes must be a list" - assert all( - isinstance(mode, UserBufferQuantizationMode) for mode in quantization_modes - ), "quantization_modes must be a list of UserBufferQuantizationMode" + if not isinstance(quantization_modes, list): + raise TypeError( + f"quantization_modes must be a list, got {type(quantization_modes).__name__}" + ) + invalid_modes = [ + mode for mode in quantization_modes if not isinstance(mode, UserBufferQuantizationMode) + ] + if invalid_modes: + raise TypeError( + "quantization_modes must be a list of UserBufferQuantizationMode, " + f"got invalid entries: {invalid_modes}" + ) if isinstance(ub_cfgs, dict) or ub_cfgs is None: ub_cfgs = [ub_cfgs] * len(quantization_modes) else: - assert len(ub_cfgs) == len( - quantization_modes - ), "Number of ub_cfgs settings must match number of quantization configurations" + if len(ub_cfgs) != len(quantization_modes): + raise ValueError( + f"Number of ub_cfgs settings ({len(ub_cfgs)}) must match number of " + f"quantization configurations ({len(quantization_modes)})" + ) global _ub_communicators - assert _ub_communicators is None, "UB communicators are already initialized." + if _ub_communicators is not None: + raise RuntimeError("UB communicators are already initialized.") _ub_communicators = {} if tex.ubuf_built_with_mpi(): # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force # an MPI_Init() here by creating a new MPI process group... - assert torch.distributed.is_mpi_available() + if not torch.distributed.is_mpi_available(): + raise RuntimeError( + "MPI backend is not available in torch.distributed but is required " + "when Userbuffers is built with MPI support" + ) _ = torch.distributed.new_group(backend="mpi") helper = tex.CommOverlapHelper() else: # Bootstrapping with torch.distributed API, so check backend and construct # intra/inter-node process groups... - assert ( - torch.distributed.is_initialized() - ), "torch.distributed must be initialized before Userbuffers" + if not torch.distributed.is_initialized(): + raise RuntimeError("torch.distributed must be initialized before using Userbuffers") if bootstrap_backend is None: bootstrap_backend = "nccl" if torch.distributed.is_mpi_available(): @@ -206,15 +222,16 @@ def initialize_ub( elif torch.distributed.is_gloo_available(): bootstrap_backend = "gloo" else: - assert bootstrap_backend in [ - "gloo", - "mpi", - "nccl", - ], "Invalid torch.distributed backend for bootstrapping Userbuffers!" - assert torch.distributed.is_backend_available(bootstrap_backend), ( - f"PyTorch must be compiled with '{bootstrap_backend}' support in order to " - f"bootstrap Userbuffers with '{bootstrap_backend}' collectives." - ) + if bootstrap_backend not in ["gloo", "mpi", "nccl"]: + raise ValueError( + f"Invalid torch.distributed backend '{bootstrap_backend}' for bootstrapping " + "Userbuffers. Must be one of: 'gloo', 'mpi', 'nccl'" + ) + if not torch.distributed.is_backend_available(bootstrap_backend): + raise RuntimeError( + f"PyTorch must be compiled with '{bootstrap_backend}' support in order to " + f"bootstrap Userbuffers with '{bootstrap_backend}' collectives." + ) world_group = torch.distributed.new_group(backend=bootstrap_backend) world_rank = torch.distributed.get_rank(world_group) @@ -333,9 +350,11 @@ def add_ub( warnings.warn( "Atomic GEMM uses a beta API from cublas and is not tested for all use cases." ) - assert ( - quantization_mode == UserBufferQuantizationMode.FP8 - ), "Atomic GEMM overlap supported only for FP8 GEMM." + if quantization_mode != UserBufferQuantizationMode.FP8: + raise ValueError( + "Atomic GEMM overlap supported only for FP8 GEMM, " + f"got quantization_mode={quantization_mode}" + ) if method in ("bulk", "external"): warnings.warn( f"At {name}, atoimic GEMM not is supported for a bulk overlap." @@ -360,20 +379,24 @@ def add_ub( "for functionality." ) if name in layers_atomic_ring_exchange: - assert atomic_gemm and method == "ring_exchange", assert_message + if not (atomic_gemm and method == "ring_exchange"): + raise ValueError(assert_message) else: if atomic_gemm and method == "ring_exchange": - assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message + if rs_ag_pairs[name] not in layers_atomic_ring_exchange: + raise ValueError(assert_message) if name in external_gemm_to_overlap: - assert method == "external", ( - f"At {name}, `external` overlap method is specified, but the selected method is" - f" {method}" - ) - assert external_gemm_to_overlap[name] in methods["ring_exchange"], ( - f"At {name}, `external` overlap method is specified, but the external gemm" - f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method" - ) + if method != "external": + raise ValueError( + f"At {name}, `external` overlap method is specified, but the selected method " + f"is {method}" + ) + if external_gemm_to_overlap[name] not in methods["ring_exchange"]: + raise ValueError( + f"At {name}, `external` overlap method is specified, but the external gemm " + f"{external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method" + ) buffer_dtype = ( torch.uint8 @@ -424,7 +447,12 @@ def add_ub( and user_ub_cfg[name]["method"] != "bulk" ): wgrad_name = name.replace("dgrad", "wgrad") - assert wgrad_name not in user_ub_cfg + if wgrad_name in user_ub_cfg: + raise ValueError( + f"Cannot specify user UB config for '{wgrad_name}' when its " + f"corresponding dgrad '{name}' uses a non-bulk overlap method " + f"('{user_ub_cfg[name]['method']}')" + ) layers_reduce_scatter_overlap.remove(wgrad_name) layers_all_gather_overlap.remove(name) layers_reduce_scatter_overlap.append(name) @@ -451,8 +479,10 @@ def get_ub(name: str, use_fp8: bool): # So favour simplicity until the correct design becomes clear. # This is mainly an internal API so we don't need to worry about future changes key = (name, UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE) - assert _ub_communicators is not None, "UB manager is not initialized." - assert key in _ub_communicators, f"UB for {name} with use_fp8={use_fp8} is not registered." + if _ub_communicators is None: + raise RuntimeError("UB manager is not initialized.") + if key not in _ub_communicators: + raise KeyError(f"UB for {name} with use_fp8={use_fp8} is not registered.") return _ub_communicators[key] @@ -608,7 +638,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): def __init__(self, name: Optional[str] = None) -> None: super().__init__() - assert torch.cuda.is_available(), "TransformerEngine needs CUDA." + if not torch.cuda.is_available(): + raise RuntimeError("TransformerEngine needs CUDA.") self.name = name self.next_iter_when_debug_should_be_run = 0 self.fp8_initialized = False @@ -694,9 +725,12 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> ] for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): if buffer_key in FP8GlobalStateManager.global_amax_buffer: - assert ( - buffer_key in FP8GlobalStateManager.global_amax_history_buffer - ), "TE internal error during amax history change." + if buffer_key not in FP8GlobalStateManager.global_amax_history_buffer: + raise RuntimeError( + "TE internal error during amax history change: " + f"buffer_key '{buffer_key}' found in global_amax_buffer " + "but missing from global_amax_history_buffer" + ) FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = self.fp8_meta[ meta_key ].amax_history[0] @@ -745,10 +779,11 @@ def _update_weight_quantizers(self) -> None: """Update the quantizers for the weight tensors.""" weight_tensors = self._get_weight_tensors() weight_quantizers = self._get_weight_quantizers() - assert len(weight_tensors) == len(weight_quantizers), ( - f"Number of weight tensors ({len(weight_tensors)}) and quantizers " - f"({len(weight_quantizers)}) must match" - ) + if len(weight_tensors) != len(weight_quantizers): + raise ValueError( + f"Number of weight tensors ({len(weight_tensors)}) and quantizers " + f"({len(weight_quantizers)}) must match" + ) for weight, quantizer in zip(weight_tensors, weight_quantizers): if quantizer is not None and isinstance(weight, QuantizedTensorStorage): weight.update_quantizer(quantizer) @@ -796,7 +831,11 @@ def reset(key): torch.zeros_like(self.fp8_meta[key].amax_history) ) else: - assert key in fp8_meta_tensors, "Cannot reset fp8 tensors." + if key not in fp8_meta_tensors: + raise KeyError( + f"Cannot reset fp8 tensors: key '{key}' not found in fp8_meta_tensors. " + f"Available keys: {list(fp8_meta_tensors.keys())}" + ) self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0]) self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][1]) @@ -937,10 +976,11 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: if not self.allow_different_data_and_param_types: for name, param in self.named_parameters(): if param is not None: - assert dtype == param.dtype, ( - "Data types for parameters must match when outside of autocasted region. " - f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" - ) + if dtype != param.dtype: + raise TypeError( + "Data types for parameters must match when outside of autocasted " + f"region. Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" + ) self.fast_setattr("activation_dtype", dtype) def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: @@ -1045,10 +1085,17 @@ def prepare_forward( delayed_scaling_recipe = self.fp8_meta["recipe"].delayed() FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) else: - assert inp.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise RuntimeError( + f"TransformerEngine needs CUDA. Got input on device: {inp.device}" + ) if self.tp_size > 1: - assert self.tp_group_initialized, "TP group not initialized." + if not self.tp_group_initialized: + raise RuntimeError( + "Tensor parallel group not initialized. Call " + "set_tensor_parallel_group() before forward pass when tp_size > 1." + ) self.set_activation_dtype(inp) self.init_fp8_metadata(num_gemms=num_gemms) @@ -1057,10 +1104,11 @@ def prepare_forward( delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() if delayed_scaling_recipe: if self.sequence_parallel: - assert self.fp8_meta["recipe"].reduce_amax, ( - "Amax reduction across tensor parallel group is " - "necessary when using sequence parallelism with FP8." - ) + if not self.fp8_meta["recipe"].reduce_amax: + raise ValueError( + "Amax reduction across tensor parallel group is " + "necessary when using sequence parallelism with FP8." + ) if not FP8GlobalStateManager.fp8_graph_capturing(): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index f3e7b57cf1..56479e882a 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -646,9 +646,8 @@ def __init__( self.ub_name = ub_name self.save_original_input = save_original_input self.single_grouped_parameter = single_grouped_parameter - assert ( - not ub_overlap_rs and not ub_overlap_ag - ), "GroupedLinear doesn't support Userbuffer overlap." + if ub_overlap_rs or ub_overlap_ag: + raise ValueError("GroupedLinear doesn't support Userbuffer overlap.") self.init_method = init_method self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name @@ -683,9 +682,11 @@ def __init__( ) self.parallel_mode = parallel_mode - assert ( - self.parallel_mode in GemmParallelModes - ), f"parallel_mode {parallel_mode} not supported" + if self.parallel_mode not in GemmParallelModes: + raise ValueError( + f"parallel_mode {parallel_mode!r} not supported." + f" Supported modes: {GemmParallelModes}" + ) if self.parallel_mode == "column": self.out_features = divide(self.out_features, self.tp_size) @@ -788,9 +789,11 @@ def make_grouped_weights(self, defer_init=False) -> None: # Re-register as a single grouped weight parameter. # Re-register as a single grouped weight parameter. - assert isinstance(grouped_weights, torch.Tensor) and ( - weight_quantizers[0] is None or not weight_quantizers[0].internal - ), "Found internal quantizer with `single_grouped_parameter=True`." + if not ( + isinstance(grouped_weights, torch.Tensor) + and (weight_quantizers[0] is None or not weight_quantizers[0].internal) + ): + raise RuntimeError("Found internal quantizer with `single_grouped_parameter=True`.") self.register_parameter( "weight", torch.nn.Parameter(grouped_weights), @@ -875,10 +878,13 @@ def forward( """ debug = self.is_debug_iter() - assert not isinstance( - inp, QuantizedTensorStorage - ), "GroupedLinear doesn't support input tensor in FP8." - assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." + if isinstance(inp, QuantizedTensorStorage): + raise TypeError("GroupedLinear doesn't support input tensor in FP8.") + if len(m_splits) != self.num_gemms: + raise ValueError( + f"Number of splits ({len(m_splits)}) should match number of" + f" GEMMs ({self.num_gemms})." + ) is_grad_enabled = torch.is_grad_enabled() @@ -969,10 +975,11 @@ def backward_dw(self): def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: """Customize quantizers based on current scaling recipe + linear.""" - assert not self.tp_size > 1, ( - "GroupedLinear doesn't support TP > 1 with Float8 current scaling. " - "Because the TP communication is handled outside of this module." - ) + if self.tp_size > 1: + raise ValueError( + "GroupedLinear doesn't support TP > 1 with Float8 current scaling. " + "Because the TP communication is handled outside of this module." + ) if fwd: for i in range(self.num_gemms): @@ -1077,7 +1084,8 @@ def _get_quantizers(self): def _get_debug_quantizers(self): original_quantizers = self._get_quantizers() - assert TEDebugState.debug_enabled + if not TEDebugState.debug_enabled: + raise RuntimeError("TEDebugState.debug_enabled must be True to get debug quantizers") names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] return tuple( diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 5beeed1262..ca59a0ebf8 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -42,10 +42,16 @@ def forward( return inp, torch.tensor([], device=inp.device) # Device check - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert index.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") + if not index.is_cuda: + raise ValueError(f"index must be a CUDA tensor, but got tensor on {index.device}.") # Shape check - assert inp.size(0) == index.size(0), "Permute not possible" + if inp.size(0) != index.size(0): + raise ValueError( + f"Permute not possible: inp.size(0) ({inp.size(0)}) must match " + f"index.size(0) ({index.size(0)})." + ) # Data type check dtype = TE_DType[inp.dtype] @@ -119,7 +125,8 @@ def forward( # None probs check if probs is not None: - assert probs.is_cuda, "TransformerEngine needs CUDA." + if not probs.is_cuda: + raise ValueError(f"probs must be a CUDA tensor, but got tensor on {probs.device}.") if probs.dtype != torch.float32: warnings.warn( @@ -136,8 +143,12 @@ def forward( probs = torch.empty(0) # Device check - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert row_id_map.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") + if not row_id_map.is_cuda: + raise ValueError( + f"row_id_map must be a CUDA tensor, but got tensor on {row_id_map.device}." + ) # Data type check dtype = TE_DType[inp.dtype] @@ -198,19 +209,30 @@ def forward( ctx.probs = probs return inp, torch.tensor([], device=inp.device), torch.tensor([], device=inp.device) - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert routing_map.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") + if not routing_map.is_cuda: + raise ValueError( + f"routing_map must be a CUDA tensor, but got tensor on {routing_map.device}." + ) if probs is not None: - assert probs.is_cuda, "TransformerEngine needs CUDA." + if not probs.is_cuda: + raise ValueError(f"probs must be a CUDA tensor, but got tensor on {probs.device}.") if pad_offsets is not None: - assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." + if not pad_offsets.is_cuda: + raise ValueError( + f"pad_offsets must be a CUDA tensor, but got tensor on {pad_offsets.device}." + ) - assert inp.size(0) == routing_map.size(0), "Permute not possible" + if inp.size(0) != routing_map.size(0): + raise ValueError( + f"Permute not possible: inp.size(0) ({inp.size(0)}) must match " + f"routing_map.size(0) ({routing_map.size(0)})." + ) num_tokens, hidden_size = inp.size() num_experts = routing_map.size(1) - assert ( - num_out_tokens is not None - ), "num_out_tokens must be provided to the fused permute function." + if num_out_tokens is None: + raise ValueError("num_out_tokens must be provided to the fused permute function.") row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts) @@ -226,13 +248,25 @@ def forward( if blockwise_recipe: fp8_scale = inp._rowwise_scale_inv.T.contiguous() scale_hidden_dim = fp8_scale.shape[1] - assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + if num_tokens != fp8_scale.shape[0]: + raise ValueError( + f"Scale and input shape mismatch: num_tokens ({num_tokens}) != " + f"fp8_scale.shape[0] ({fp8_scale.shape[0]}). " + f"Input shape: ({num_tokens}, {hidden_size}), " + f"scale shape: {tuple(fp8_scale.shape)}." + ) inp = inp._rowwise_data # mxfp8 scaling elif mxfp8_recipe: fp8_scale = inp._rowwise_scale_inv.contiguous() scale_hidden_dim = fp8_scale.shape[1] - assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + if num_tokens != fp8_scale.shape[0]: + raise ValueError( + f"Scale and input shape mismatch: num_tokens ({num_tokens}) != " + f"fp8_scale.shape[0] ({fp8_scale.shape[0]}). " + f"Input shape: ({num_tokens}, {hidden_size}), " + f"scale shape: {tuple(fp8_scale.shape)}." + ) inp = inp._rowwise_data # per-tensor scaling elif per_tensor_recipe: @@ -318,9 +352,11 @@ def backward( probs_grad = None if ctx.needs_input_grad[0]: row_id_map, pad_offsets = ctx.saved_tensors - assert not isinstance( - permuted_act_grad, QuantizedTensor - ), "The backward of moe_permute does not support FP8." + if isinstance(permuted_act_grad, QuantizedTensor): + raise TypeError( + "The backward of moe_permute does not support FP8, but got " + f"QuantizedTensor of type {type(permuted_act_grad).__name__}." + ) act_grad, probs_grad = triton_permutation.unpermute_with_mask_map( permuted_act_grad, row_id_map, @@ -360,17 +396,30 @@ def forward( with_probs = merging_probs is not None if with_probs: - assert merging_probs.is_cuda, "TransformerEngine needs CUDA." + if not merging_probs.is_cuda: + raise ValueError( + "merging_probs must be a CUDA tensor, but got tensor on " + f"{merging_probs.device}." + ) # Device check - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert row_id_map.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") + if not row_id_map.is_cuda: + raise ValueError( + f"row_id_map must be a CUDA tensor, but got tensor on {row_id_map.device}." + ) if pad_offsets is not None: - assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." + if not pad_offsets.is_cuda: + raise ValueError( + f"pad_offsets must be a CUDA tensor, but got tensor on {pad_offsets.device}." + ) - assert not isinstance( - inp, QuantizedTensor - ), "The forward of moe_unpermute does not support FP8." + if isinstance(inp, QuantizedTensor): + raise TypeError( + "The forward of moe_unpermute does not support FP8, but got " + f"QuantizedTensor of type {type(inp).__name__}." + ) unpermuted_output, _ = triton_permutation.unpermute_with_mask_map( inp, row_id_map, @@ -427,13 +476,23 @@ def backward(ctx, unpermuted_act_grad): fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous() unpermuted_act_grad = unpermuted_act_grad._rowwise_data scale_hidden_dim = fp8_scale.shape[1] - assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + if ctx.num_tokens != fp8_scale.shape[0]: + raise ValueError( + f"Scale and input shape mismatch: num_tokens ({ctx.num_tokens}) != " + f"fp8_scale.shape[0] ({fp8_scale.shape[0]}). " + f"Scale shape: {tuple(fp8_scale.shape)}." + ) # mxfp8 scaling elif mxfp8_recipe: fp8_scale = unpermuted_act_grad._rowwise_scale_inv.contiguous() unpermuted_act_grad = unpermuted_act_grad._rowwise_data scale_hidden_dim = fp8_scale.shape[1] - assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + if ctx.num_tokens != fp8_scale.shape[0]: + raise ValueError( + f"Scale and input shape mismatch: num_tokens ({ctx.num_tokens}) != " + f"fp8_scale.shape[0] ({fp8_scale.shape[0]}). " + f"Scale shape: {tuple(fp8_scale.shape)}." + ) else: raise ValueError("Unsupported FP8 recipe") else: @@ -441,10 +500,13 @@ def backward(ctx, unpermuted_act_grad): fp8_dtype = None fp8_scale = None + permuted_scale = None if ctx.with_probs: - assert ( - not fp8 - ), "The backward of moe_unpermute with merging probs does not support FP8." + if fp8: + raise TypeError( + "The backward of moe_unpermute with merging probs does not support FP8, " + f"but got FP8 gradient with dtype {fp8_dtype}." + ) act_grad, probs_grad = ( triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs( unpermuted_act_grad, @@ -619,10 +681,12 @@ def moe_permute_and_pad_with_probs( align_size : int the alignment size for the input tensor. """ - assert ( - tokens_per_expert is not None - ), "tokens_per_expert must be provided to the fused permute padding function." - assert align_size > 0, f"align_size must be positive, got {align_size}" + if tokens_per_expert is None: + raise ValueError( + "tokens_per_expert must be provided to the fused permute padding function." + ) + if align_size <= 0: + raise ValueError(f"align_size must be positive, got {align_size}.") # Ensure tokens_per_expert is on the same device as input to avoid device transfers if tokens_per_expert.device != inp.device: @@ -713,15 +777,27 @@ def forward( if not inp.numel(): return inp, probs - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert split_sizes.is_cuda, "TransformerEngine needs CUDA." - assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") + if not split_sizes.is_cuda: + raise ValueError( + f"split_sizes must be a CUDA tensor, but got tensor on {split_sizes.device}." + ) + if not sorted_idxs.is_cuda: + raise ValueError( + f"sorted_idxs must be a CUDA tensor, but got tensor on {sorted_idxs.device}." + ) if probs is not None: - assert probs.is_cuda, "TransformerEngine needs CUDA." + if not probs.is_cuda: + raise ValueError(f"probs must be a CUDA tensor, but got tensor on {probs.device}.") num_tokens, hidden_size = inp.shape num_splits = split_sizes.size(0) - assert num_splits == sorted_idxs.size(0) + if num_splits != sorted_idxs.size(0): + raise ValueError( + f"split_sizes.size(0) ({num_splits}) must match " + f"sorted_idxs.size(0) ({sorted_idxs.size(0)})." + ) fp8 = isinstance(inp, Float8Tensor) if fp8: diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index eba547afb0..47e6d5c8dc 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -97,7 +97,8 @@ def check_recipe_support(recipe: Recipe) -> None: recipe_supported, unsupported_reason = check_fp8_block_scaling_support() elif isinstance(recipe, MXFP8BlockScaling): recipe_supported, unsupported_reason = check_mxfp8_support() - assert recipe_supported, unsupported_reason + if not recipe_supported: + raise RuntimeError(unsupported_reason) def get_default_fp8_recipe() -> Recipe: diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py index 685b2c5548..22a6a41eb1 100644 --- a/transformer_engine/pytorch/tensor/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -203,7 +203,8 @@ def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: # Parameter construction calls detach()/alias-like paths. if func in (torch.ops.aten.detach.default, torch.ops.aten.alias.default): src = args[0] - assert isinstance(src, GroupedTensor) + if not isinstance(src, GroupedTensor): + raise TypeError(f"Expected GroupedTensor, got {type(src).__name__}") if func == torch.ops.aten.detach.default: return make_wrapper_like(src, requires_grad=False) return make_wrapper_like(src, requires_grad=src.requires_grad) @@ -212,7 +213,8 @@ def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: # Handle this explicitly so grouped parameters can be created safely. if func == torch.ops.aten.expand.default: src = args[0] - assert isinstance(src, GroupedTensor) + if not isinstance(src, GroupedTensor): + raise TypeError(f"Expected GroupedTensor, got {type(src).__name__}") expanded_shape = tuple(args[1]) src_shape = tuple(src.shape) if len(expanded_shape) == len(src_shape): @@ -228,7 +230,8 @@ def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: if func == torch.ops.aten.expand_as.default: src = args[0] other = args[1] - assert isinstance(src, GroupedTensor) + if not isinstance(src, GroupedTensor): + raise TypeError(f"Expected GroupedTensor, got {type(src).__name__}") if other is src: return _GroupedIdentityFunc.apply(src) if tuple(other.shape) == tuple(src.shape): @@ -240,7 +243,8 @@ def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: # returning a flat view of grouped backing storage. if func in (torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default): src = args[0] - assert isinstance(src, GroupedTensor) + if not isinstance(src, GroupedTensor): + raise TypeError(f"Expected GroupedTensor, got {type(src).__name__}") target_shape = tuple(args[1]) if target_shape in ((-1,), (src.numel(),)): if src.rowwise_data is not None: @@ -317,7 +321,11 @@ def maybe_update_inplace(arg, new_arg, schema_arg): for arg, new_arg, schema_arg in zip(args, new_args, schema_args): maybe_update_inplace(arg, new_arg, schema_arg) for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]): - assert kwarg == new_kwarg == schema_arg.name, "name of kwarg should match schema" + if kwarg != new_kwarg or kwarg != schema_arg.name: + raise RuntimeError( + f"Name of kwarg should match schema, got kwarg={kwarg!r}," + f" new_kwarg={new_kwarg!r}, schema_arg.name={schema_arg.name!r}" + ) maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg) return None diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index e7509f3994..3e8bf3f2f3 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -327,11 +327,14 @@ def update_usage( # If both rowwise and columnwise are requested, create columnwise from rowwise if needed if rowwise_usage and columnwise_usage: - assert ( - self._rowwise_data is not None - and self._rowwise_scale_inv is not None - and self._amax_rowwise is not None - ), "Cannot update to rowwise and columnwise usage because rowwise data is None." + if ( + self._rowwise_data is None + or self._rowwise_scale_inv is None + or self._amax_rowwise is None + ): + raise RuntimeError( + "Cannot update to rowwise and columnwise usage because rowwise data is None." + ) if self._columnwise_data is None or self._columnwise_scale_inv is None: self._create_columnwise() return @@ -381,16 +384,16 @@ def _create_columnwise(self): """ Update columnwise data and columnwise scale inv. Can only be used when using 2D scaling. """ - assert ( - self._quantizer is not None and self._quantizer.with_2d_quantization - ), "Cannot create columnwise data without 2D quantization enabled." + if self._quantizer is None or not self._quantizer.with_2d_quantization: + raise RuntimeError("Cannot create columnwise data without 2D quantization enabled.") rowwise_data = self._rowwise_data if not rowwise_data.is_contiguous(): rowwise_data = rowwise_data.contiguous() # NVFP4 requires a specialized transpose that handles nibble repacking self._columnwise_data = tex.nvfp4_data_transpose(rowwise_data, out=self._columnwise_data) if self._columnwise_scale_inv is None: - assert self._quantizer is not None + if self._quantizer is None: + raise RuntimeError("Cannot create columnwise scale inverse: quantizer is None.") # Use logical shape (self.size()), not packed byte shape (rowwise_data.shape) # NVFP4 packs 2 elements per byte, so rowwise_data.shape[-1] is K/2 logical_shape = self.size() @@ -400,8 +403,18 @@ def _create_columnwise(self): dtype=self._rowwise_scale_inv.dtype, device=self._rowwise_scale_inv.device, ) - assert len(self._rowwise_scale_inv.shape) == 2 - assert len(self._columnwise_scale_inv.shape) == 2 + if len(self._rowwise_scale_inv.shape) != 2: + raise ValueError( + "Expected rowwise_scale_inv to be 2D, but got" + f" {len(self._rowwise_scale_inv.shape)}D with shape" + f" {self._rowwise_scale_inv.shape}." + ) + if len(self._columnwise_scale_inv.shape) != 2: + raise ValueError( + "Expected columnwise_scale_inv to be 2D, but got" + f" {len(self._columnwise_scale_inv.shape)}D with shape" + f" {self._columnwise_scale_inv.shape}." + ) # rowwise_scale_inv has shape [M_padded, K_tiles] where each tile's scale # is repeated 16 times (once per row in the 16x16 tile). diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index d23892af94..c80bc8aaa4 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -37,19 +37,31 @@ def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): """ if isinstance(tensor, Float8Tensor): old_raw_data = tensor._data - assert old_raw_data.dtype == new_raw_data.dtype, "The data types of raw data don't match" + if old_raw_data.dtype != new_raw_data.dtype: + raise ValueError( + "The data types of raw data don't match: " + f"old dtype={old_raw_data.dtype}, new dtype={new_raw_data.dtype}" + ) new_raw_data.detach().copy_(old_raw_data) tensor._data = new_raw_data del old_raw_data elif isinstance(tensor, Float8BlockwiseQTensor): old_raw_data = tensor._rowwise_data - assert old_raw_data.dtype == new_raw_data.dtype, "The data types of raw data don't match" + if old_raw_data.dtype != new_raw_data.dtype: + raise ValueError( + "The data types of raw data don't match: " + f"old dtype={old_raw_data.dtype}, new dtype={new_raw_data.dtype}" + ) new_raw_data.detach().copy_(old_raw_data) tensor._rowwise_data = new_raw_data del old_raw_data elif isinstance(tensor, NVFP4Tensor): old_rowwise = tensor._rowwise_data - assert old_rowwise.dtype == new_raw_data.dtype, "The data types of raw data don't match" + if old_rowwise.dtype != new_raw_data.dtype: + raise ValueError( + f"The data types of raw data don't match: {old_rowwise.dtype} vs" + f" {new_raw_data.dtype}" + ) new_raw_data.detach().copy_(old_rowwise) tensor._rowwise_data = new_raw_data del old_rowwise @@ -276,10 +288,16 @@ def _cast_master_weights_to_fp8_delayed_scaling( continue # If master weight is not None, start_offset must be a valid value. - assert start_offset is not None - assert start_offset >= 0 + if start_offset is None: + raise ValueError("start_offset must not be None when master_weight is provided") + if start_offset < 0: + raise ValueError(f"start_offset must be non-negative, got {start_offset}") end_offset = start_offset + master_weight.numel() - assert end_offset <= model_weight.numel() + if end_offset > model_weight.numel(): + raise ValueError( + f"end_offset ({end_offset}) exceeds model_weight numel ({model_weight.numel()}), " + f"start_offset={start_offset}, master_weight numel={master_weight.numel()}" + ) # master_weight may be smaller than model_weight because it could be distributed across # multiple ranks. So we need to create a dummy weight using the raw data from model_weight. @@ -363,9 +381,21 @@ def _cast_master_weights_to_fp8_current_scaling( # Make sure all the model weights have the same numerical options. quantizer = model_weight._get_quantizer() - assert quantizer.dtype == fp8_dtype - assert quantizer.force_pow_2_scales == force_pow_2_scales - assert quantizer.amax_epsilon == amax_epsilon + if quantizer.dtype != fp8_dtype: + raise ValueError( + "All model weights must have the same fp8 dtype, " + f"expected {fp8_dtype} but got {quantizer.dtype}" + ) + if quantizer.force_pow_2_scales != force_pow_2_scales: + raise ValueError( + "All model weights must have the same force_pow_2_scales, " + f"expected {force_pow_2_scales} but got {quantizer.force_pow_2_scales}" + ) + if quantizer.amax_epsilon != amax_epsilon: + raise ValueError( + "All model weights must have the same amax_epsilon, " + f"expected {amax_epsilon} but got {quantizer.amax_epsilon}" + ) scales.append(quantizer.scale.view(1)) scale_invs.append(model_weight._scale_inv.view(1)) @@ -479,19 +509,47 @@ def _cast_master_weights_to_fp8_blockwise_scaling( # Make sure all the model weights have the same numerical options. quantizer = model_weight._get_quantizer() - assert block_len == quantizer.block_len - assert fp8_dtype == quantizer.dtype - assert force_pow_2_scales == quantizer.force_pow_2_scales - assert amax_epsilon == quantizer.amax_epsilon + if block_len != quantizer.block_len: + raise ValueError( + "All model weights must have the same block_len, " + f"expected {block_len} but got {quantizer.block_len}" + ) + if fp8_dtype != quantizer.dtype: + raise ValueError( + "All model weights must have the same fp8 dtype, " + f"expected {fp8_dtype} but got {quantizer.dtype}" + ) + if force_pow_2_scales != quantizer.force_pow_2_scales: + raise ValueError( + "All model weights must have the same force_pow_2_scales, " + f"expected {force_pow_2_scales} but got {quantizer.force_pow_2_scales}" + ) + if amax_epsilon != quantizer.amax_epsilon: + raise ValueError( + "All model weights must have the same amax_epsilon, " + f"expected {amax_epsilon} but got {quantizer.amax_epsilon}" + ) scale_shape = quantizer.get_scale_shape(model_weight.shape, False) amax = packed_amaxes[cu_amax_sizes[i] : cu_amax_sizes[i + 1]].reshape(scale_shape) scale = torch.empty(scale_shape, dtype=torch.float32, device=device) scale_inv = model_weight._rowwise_scale_inv - assert len(scale_shape) == 2 - assert len(scale_inv.shape) == 2 - assert scale_inv.shape[0] == scale_shape[0] - assert scale_inv.shape[1] == scale_shape[1] + if len(scale_shape) != 2: + raise ValueError(f"scale_shape must be 2D, got {len(scale_shape)}D shape {scale_shape}") + if len(scale_inv.shape) != 2: + raise ValueError( + f"scale_inv must be 2D, got {len(scale_inv.shape)}D shape {scale_inv.shape}" + ) + if scale_inv.shape[0] != scale_shape[0]: + raise ValueError( + f"scale_inv dim 0 mismatch: scale_inv.shape={scale_inv.shape}," + f" scale_shape={scale_shape}" + ) + if scale_inv.shape[1] != scale_shape[1]: + raise ValueError( + f"scale_inv dim 1 mismatch: scale_inv.shape={scale_inv.shape}," + f" scale_shape={scale_shape}" + ) amaxes.append(amax) scales.append(scale) @@ -499,7 +557,11 @@ def _cast_master_weights_to_fp8_blockwise_scaling( # Compute amax of the master weight and store it in packed_amaxes. if master_weight is not None: - assert len(model_weight.shape) == 2 + if len(model_weight.shape) != 2: + raise ValueError( + "model_weight must be 2D for blockwise scaling, " + f"got {len(model_weight.shape)}D shape {model_weight.shape}" + ) h, w = model_weight.shape tex.fp8_block_scaling_compute_partial_amax( master_weight, amax, h, w, start_offset, block_len @@ -550,7 +612,11 @@ def _cast_master_weights_to_fp8_blockwise_scaling( end_offset = start_offset + master_weight.numel() if not use_fsdp_shard_model_weights: model_weight_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset] - assert len(model_weight.shape) == 2 + if len(model_weight.shape) != 2: + raise ValueError( + "model_weight must be 2D for blockwise scaling partial cast, " + f"got {len(model_weight.shape)}D shape {model_weight.shape}" + ) h, w = model_weight.shape tex.fp8_block_scaling_partial_cast( master_weight, model_weight_fragment, scale, h, w, start_offset, block_len, fp8_dtype @@ -581,9 +647,12 @@ def _cast_master_weights_to_nvfp4_2d( amax_targets: List[Optional[torch.Tensor]] = [] for model_weight, _, _, _ in params: quantizer = model_weight._get_quantizer() - assert isinstance(quantizer, NVFP4Quantizer) - assert quantizer.with_2d_quantization, "NVFP4 2D quantization must be enabled." - assert len(model_weight.shape) == 2 + if not isinstance(quantizer, NVFP4Quantizer): + raise TypeError(f"Expected NVFP4Quantizer, got {type(quantizer).__name__}") + if not quantizer.with_2d_quantization: + raise ValueError("NVFP4 2D quantization must be enabled.") + if len(model_weight.shape) != 2: + raise ValueError(f"Expected 2D model weight, got {len(model_weight.shape)}D") h, w = model_weight.shape tile_h = (h + block_len - 1) // block_len tile_w = (w + block_len - 1) // block_len @@ -616,13 +685,15 @@ def _cast_master_weights_to_nvfp4_2d( scale = packed_scales[cu_amax_sizes[i] : cu_amax_sizes[i + 1]].reshape(scale_shape) global_amax_view = global_amax_views[i] - assert model_weight._rowwise_scale_inv is not None + if model_weight._rowwise_scale_inv is None: + raise RuntimeError("model_weight._rowwise_scale_inv must not be None") amaxes.append(amax) scales.append(scale) if master_weight is not None and master_weight.numel() > 0: - assert len(model_weight.shape) == 2 + if len(model_weight.shape) != 2: + raise ValueError(f"Expected 2D model weight, got {len(model_weight.shape)}D") h, w = model_weight.shape # Collect for batched processing master_weight_list.append(master_weight) @@ -728,7 +799,8 @@ def _cast_master_weights_to_nvfp4_2d( byte_start = start_offset // 2 byte_end = (end_offset + 1) // 2 model_weight_fragment = rowwise_bytes[byte_start:byte_end] - assert len(model_weight.shape) == 2 + if len(model_weight.shape) != 2: + raise ValueError(f"Expected 2D model weight, got {len(model_weight.shape)}D") h, w = model_weight.shape partial_cast_inp_list.append(master_weight) @@ -793,9 +865,15 @@ def _cast_master_weights_to_fp8_mxfp8_scaling( cu_colwise_amax_sizes = [0] for model_weight, _, _, _ in params: rowwise_shape = model_weight._rowwise_scale_inv.shape - assert len(rowwise_shape) == 2 + if len(rowwise_shape) != 2: + raise ValueError( + f"rowwise_scale_inv must be 2D, got {len(rowwise_shape)}D shape {rowwise_shape}" + ) colwise_shape = model_weight._columnwise_scale_inv.shape - assert len(colwise_shape) == 2 + if len(colwise_shape) != 2: + raise ValueError( + f"columnwise_scale_inv must be 2D, got {len(colwise_shape)}D shape {colwise_shape}" + ) cu_rowwise_amax_sizes.append( cu_rowwise_amax_sizes[-1] + rowwise_shape[0] * rowwise_shape[1] ) @@ -834,7 +912,11 @@ def _cast_master_weights_to_fp8_mxfp8_scaling( # Compute amax of the master weight and store it in packed_amaxes. if master_weight is not None: - assert len(model_weight.shape) == 2 + if len(model_weight.shape) != 2: + raise ValueError( + "model_weight must be 2D for MXFP8 scaling, " + f"got {len(model_weight.shape)}D shape {model_weight.shape}" + ) h, w = model_weight.shape tex.mxfp8_scaling_compute_partial_amax( master_weight, amax_rowwise, amax_colwise, h, w, start_offset @@ -878,7 +960,11 @@ def _cast_master_weights_to_fp8_mxfp8_scaling( else: rowwise_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset] colwise_fragment = model_weight._columnwise_data.reshape(-1)[start_offset:end_offset] - assert len(model_weight.shape) == 2 + if len(model_weight.shape) != 2: + raise ValueError( + "model_weight must be 2D for MXFP8 scaling partial cast, " + f"got {len(model_weight.shape)}D shape {model_weight.shape}" + ) h, w = model_weight.shape tex.mxfp8_scaling_partial_cast( master_weight, @@ -966,7 +1052,8 @@ def _nvfp4_2d_multi_tensor_transpose(nvfp4_tensors: List[NVFP4Tensor]): # Allocate columnwise_scale_inv if needed if tensor._columnwise_scale_inv is None: - assert tensor._quantizer is not None + if tensor._quantizer is None: + raise RuntimeError("tensor._quantizer must not be None") columnwise_scale_inv_shape = tensor._quantizer.get_scale_shape(logical_shape, True) columnwise_scale_inv = torch.empty( columnwise_scale_inv_shape, diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 868cbbdac8..4b96ccf739 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -373,23 +373,35 @@ def __init__( self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm if parallel_attention_mlp: - assert self.layer_type == "encoder", "parallel_attention requires layer_type='encoder'" - assert not self.apply_residual_connection_post_layernorm, ( - "parallel_attention and apply_residual_connection_post_layernorm " - "not supported simultaneously." - ) - assert ( - not self.output_layernorm - ), "parallel_attention and output_layernorm not supported simultaneously" + if self.layer_type != "encoder": + raise ValueError( + "parallel_attention requires layer_type='encoder', " + f"but got layer_type={self.layer_type!r}" + ) + if self.apply_residual_connection_post_layernorm: + raise ValueError( + "parallel_attention and apply_residual_connection_post_layernorm " + "are not supported simultaneously." + ) + if self.output_layernorm: + raise ValueError( + "parallel_attention and output_layernorm are not supported simultaneously." + ) self.parallel_attention_mlp = parallel_attention_mlp - assert layer_type in LayerTypes, f"layer_type {layer_type} not supported" + if layer_type not in LayerTypes: + raise ValueError( + f"layer_type {layer_type!r} is not supported. " + f"Supported types are: {', '.join(repr(t) for t in LayerTypes)}" + ) if not fuse_qkv_params: - assert ( - not fuse_wgrad_accumulation - ), "Gradient accumulation fusion requires single QKV parameter." + if fuse_wgrad_accumulation: + raise ValueError( + "Gradient accumulation fusion (fuse_wgrad_accumulation=True) " + "requires fuse_qkv_params=True, but fuse_qkv_params is False." + ) if not fuse_qkv_params: qkv_weight_interleaved = False @@ -796,32 +808,57 @@ def forward( }: enc_dec_bottom_right_diagonal = True - assert ( - self_attn_mask_type in AttnMaskTypes - ), f"self_attn_mask_type {self_attn_mask_type} not supported" - assert ( - enc_dec_attn_mask_type in AttnMaskTypes - ), f"enc_dec_attn_mask_type {enc_dec_attn_mask_type} not supported" + if self_attn_mask_type not in AttnMaskTypes: + raise ValueError( + f"self_attn_mask_type {self_attn_mask_type!r} is not supported. " + f"Supported types are: {', '.join(repr(t) for t in AttnMaskTypes)}" + ) + if enc_dec_attn_mask_type not in AttnMaskTypes: + raise ValueError( + f"enc_dec_attn_mask_type {enc_dec_attn_mask_type!r} is not supported. " + f"Supported types are: {', '.join(repr(t) for t in AttnMaskTypes)}" + ) hidden_states = hidden_states.contiguous() if self.sequence_parallel and self.seq_length is not None: - assert ( - hidden_states.shape[0] == self.seq_length // self.tp_size - ), "Sequence dimension must be split across TP group when using sequence parallel." + if hidden_states.shape[0] != self.seq_length // self.tp_size: + raise ValueError( + "Sequence dimension must be split across TP group when using " + "sequence parallel. Expected hidden_states.shape[0] to be " + f"{self.seq_length // self.tp_size} " + f"(seq_length={self.seq_length} // tp_size={self.tp_size}), " + f"but got {hidden_states.shape[0]}." + ) if ( "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary" ) and attention_mask is not None: - assert all( - attention_mask[i].dtype == torch.bool for i in range(len(attention_mask)) - ), "Attention mask must be a boolean tensor or a list/tuple of two boolean tensors" + if not all(attention_mask[i].dtype == torch.bool for i in range(len(attention_mask))): + non_bool_dtypes = [ + (i, attention_mask[i].dtype) + for i in range(len(attention_mask)) + if attention_mask[i].dtype != torch.bool + ] + raise TypeError( + "Attention mask must be a boolean tensor or a list/tuple of boolean " + f"tensors, but found non-bool dtypes at indices: {non_bool_dtypes}" + ) if ( "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary" ) and enc_dec_attn_mask is not None: - assert all( + if not all( enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask)) - ), "Encoder-decoder attention mask must be boolean tensor(s)" + ): + non_bool_dtypes = [ + (i, enc_dec_attn_mask[i].dtype) + for i in range(len(enc_dec_attn_mask)) + if enc_dec_attn_mask[i].dtype != torch.bool + ] + raise TypeError( + "Encoder-decoder attention mask must be boolean tensor(s), " + f"but found non-bool dtypes at indices: {non_bool_dtypes}" + ) # For AMP if torch.is_autocast_enabled(): diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index b1cc3be19d..a23e822f91 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -146,7 +146,8 @@ def compare_tensors(a: torch.Tensor, b: torch.Tensor) -> None: def ensure_divisibility(numerator: int, denominator: int) -> None: """Ensure that numerator is divisible by the denominator.""" - assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}" + if numerator % denominator != 0: + raise ValueError(f"{numerator} is not divisible by {denominator}") def divide(numerator: int, denominator: int) -> int: @@ -270,13 +271,16 @@ def forward( @staticmethod def backward(ctx, *grad_outputs): # pylint: disable=missing-function-docstring - assert len(grad_outputs) > 0, "No gradients received for backprop!" + if len(grad_outputs) == 0: + raise RuntimeError("No gradients received for backprop!") if isinstance(ctx.split_size_or_sections, (list, tuple)): split_sizes = ctx.split_size_or_sections - assert len(grad_outputs) == len( - split_sizes - ), "Unequal number of gradients vs split sections for backprop!" + if len(grad_outputs) != len(split_sizes): + raise RuntimeError( + f"Unequal number of gradients ({len(grad_outputs)}) vs " + f"split sections ({len(split_sizes)}) for backprop!" + ) if isinstance(ctx.split_size_or_sections, int): split_sizes = [ctx.split_size_or_sections] * len(grad_outputs) dims = len(grad_outputs[0].shape) @@ -370,7 +374,8 @@ def validate_rng_states_func(get_rng_tracker: Callable) -> None: """Checks if passed in param function has everything required for tensor/model and sequence parallel. """ - assert callable(get_rng_tracker), "get_rng_tracker is not a valid function" + if not callable(get_rng_tracker): + raise TypeError(f"get_rng_tracker must be callable, got {type(get_rng_tracker).__name__}") rng_tracker = None try: @@ -378,15 +383,13 @@ def validate_rng_states_func(get_rng_tracker: Callable) -> None: except Exception as e: raise RuntimeError("Cannot call get_rng_tracker function") from e - assert hasattr(rng_tracker, "get_states") and callable( - rng_tracker.get_states - ), "rng_tracker object does not have valid method get_states" - assert hasattr(rng_tracker, "set_states") and callable( - rng_tracker.set_states - ), "rng_tracker object does not have valid method set_states" - assert hasattr(rng_tracker, "fork") and callable( - rng_tracker.fork - ), "rng_tracker object does not have valid method fork" + for method_name in ("get_states", "set_states", "fork"): + if not hasattr(rng_tracker, method_name) or not callable(getattr(rng_tracker, method_name)): + raise TypeError( + f"rng_tracker object ({type(rng_tracker).__name__}) does not have " + f"a valid callable method '{method_name}'. " + "Required methods: get_states, set_states, fork." + ) validate_ctx_manager(rng_tracker.fork) @@ -397,11 +400,12 @@ def assert_viewless_tensor(tensor: torch.Tensor, extra_msg: Optional[str] = None return [assert_viewless_tensor(t) for t in tensor] if not isinstance(tensor, torch.Tensor): return tensor - assert tensor._base is None, ( - "Ensure tensor._base is None before setting tensor.data or storing " - "tensor to memory buffer. Otherwise, a memory leak will occur (and " - f"likely accumulate over iterations). {extra_msg}" - ) + if tensor._base is not None: + raise ValueError( + "Ensure tensor._base is None before setting tensor.data or storing " + "tensor to memory buffer. Otherwise, a memory leak will occur (and " + f"likely accumulate over iterations). {extra_msg}" + ) return tensor @@ -439,11 +443,13 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None: """Assert that tensor or tensors dimensions are supported for FP8 TN GEMM.""" for tensor in tensors: - assert math.prod(tensor.shape[:-1]) % 8 == 0 and tensor.shape[-1] % 16 == 0, ( - "FP8 execution requires the product of all dimensions except the last to be divisible" - " by 8 and the last dimension to be divisible by 16, but got tensor with" - f" dims={list(tensor.size())}" - ) + if math.prod(tensor.shape[:-1]) % 8 != 0 or tensor.shape[-1] % 16 != 0: + raise ValueError( + "FP8 execution requires the product of all dimensions except the last to be" + " divisible by 8 and the last dimension to be divisible by 16, but got tensor" + f" with dims={list(tensor.size())} (product of leading dims =" + f" {math.prod(tensor.shape[:-1])}, last dim = {tensor.shape[-1]})" + ) def is_bf16_compatible() -> bool: @@ -741,7 +747,9 @@ def __cuda_array_interface__(self): def torch_dtype_to_np_typestr(self): """Convert PyTorch dtype to numpy typestr.""" ret = _torch_dtype_to_np_typestr_dict.get(self.dtype) - assert ret is not None, f"Unsupported dtype: {self.dtype}" + if ret is None: + supported = ", ".join(str(d) for d in _torch_dtype_to_np_typestr_dict) + raise TypeError(f"Unsupported dtype: {self.dtype}. Supported dtypes: {supported}") return ret @@ -780,4 +788,7 @@ def convert_to_torch_tensor(tensor: Union[_WeakRefTensor, torch.Tensor]) -> torc return x if x is None: return None - raise TypeError(f"Invalid type {type(x)} to make weak ref") + raise TypeError( + f"Invalid type {type(x).__name__} to make weak ref. " + "Valid types are: torch.Tensor, tuple, list, dict, int, float, bool, and None." + )