From 3b34388ae6138473cad953fb17e00f7b4e537df6 Mon Sep 17 00:00:00 2001 From: Soowon Jeong Date: Thu, 9 Apr 2026 16:21:50 +0900 Subject: [PATCH 1/2] [DLight] Add CPU Reduction schedule rule for softmax-like operators Add a DLight schedule rule targeting CPU reduction patterns (softmax, layer norm, RMS norm) that previously had no CPU-specific schedule, causing LLVM auto-vectorization to produce suboptimal code. This addresses apache/tvm#18569 where RVV softmax is 1.34x slower than scalar due to: - Excessive loop unrolling (2345 -> 1193 LLVM IR lines, -49%) - Harmful fixed-width vector usage on scalable-vector targets - No parallelization of the batch axis The rule applies the following schedule: 1. Parallelize leading spatial axes (batch dimension) 2. Compute all blocks under the spatial loop for locality 3. Vectorize injective blocks (exp, norm) on inner axis 4. Split reduction inner axis to VLEN-sized chunks with unroll annotation to guide LLVM codegen Assembly instruction count comparison (shape=(14,185), fast_softmax): - RV scalar baseline: 1463 instructions - RVV unscheduled: 3282 instructions (2.2x bloat, the bug) - RVV with schedule: 1111 instructions (-66% vs unscheduled) Tested with softmax and fast_softmax across 7 shapes from (1,10) to (1,30522). 
--- python/tvm/s_tir/dlight/cpu/__init__.py | 1 + python/tvm/s_tir/dlight/cpu/reduction.py | 147 ++++++++++ .../python/s_tir/dlight/test_cpu_reduction.py | 270 ++++++++++++++++++ 3 files changed, 418 insertions(+) create mode 100644 python/tvm/s_tir/dlight/cpu/reduction.py create mode 100644 tests/python/s_tir/dlight/test_cpu_reduction.py diff --git a/python/tvm/s_tir/dlight/cpu/__init__.py b/python/tvm/s_tir/dlight/cpu/__init__.py index 8743c616bb10..20e1e9a3b829 100644 --- a/python/tvm/s_tir/dlight/cpu/__init__.py +++ b/python/tvm/s_tir/dlight/cpu/__init__.py @@ -20,3 +20,4 @@ """ from .gemv import GEMV +from .reduction import Reduction diff --git a/python/tvm/s_tir/dlight/cpu/reduction.py b/python/tvm/s_tir/dlight/cpu/reduction.py new file mode 100644 index 000000000000..60d660603b3c --- /dev/null +++ b/python/tvm/s_tir/dlight/cpu/reduction.py @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""CPU reduction rule for operators including softmax, layer norm, RMS norm, etc.""" + +from tvm import s_tir, tirx +from tvm.target import Target +from tvm.target.codegen import llvm_get_vector_width + +from ..analysis import normalize_prim_func +from ..base import get_extent +from .base import CPUScheduleRule + + +class Reduction(CPUScheduleRule): + """CPU reduction rule for softmax, layer norm, RMS norm, and similar operators. + + Targets patterns with a mix of reduction (SR) and injective (SS) blocks, + where all blocks share the same leading spatial axes. + Example: softmax = maxelem(SR) -> exp(SS) -> expsum(SR) -> norm(SS). + + Schedule strategy: + 1. Parallelize leading spatial axes (batch dimension). + 2. Move all blocks under the spatial loop via compute_at. + 3. Vectorize injective blocks (exp, delta, norm) on their inner axis. + 4. Split reduction inner axis to VLEN-sized chunks and annotate for + LLVM unrolling, preventing harmful full-unroll by the backend. + + Note: vectorized reduction via rfactor is not used here because TVM's + rfactor primitive requires the reduction block to be the first child of + its enclosing loop, which is incompatible with compute_at when multiple + blocks share the same spatial loop. A follow-up using RVV reduction + intrinsics (vfredmax/vfredusum) via tensorize can address this. + """ + + def apply( # pylint: disable=too-many-locals,too-many-return-statements,too-many-branches + self, + func: tirx.PrimFunc, + target: Target, + _: bool, + ) -> None | s_tir.Schedule | list[s_tir.Schedule]: + if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target): + return None + + sch = s_tir.Schedule(func) + block_infos = normalize_prim_func(sch) + if block_infos is None or len(block_infos) < 2: + return None + + # Must have at least one reduction block and last block must be injective. 
+ has_reduction = any(not bi.is_injective() for bi in block_infos) + if not has_reduction or not block_infos[-1].is_injective(): + return None + + # All blocks must have at least one leading spatial axis. + for bi in block_infos: + dk = bi.dom_kind() + if not dk or dk[0] != "S": + return None + + # Find the number of leading spatial axes (from the first reduction block). + first_reduction = next(bi for bi in block_infos if not bi.is_injective()) + dom_kind = first_reduction.dom_kind() + num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S")) + if num_leading_s == 0: + return None + + # Determine vector width from target. + try: + vlen_bits = llvm_get_vector_width(target) + except Exception: # pylint: disable=broad-except + vlen_bits = 128 + dtype_bits = 32 # default float32 + vec_lanes = max(vlen_bits // dtype_bits, 4) + + # --- Phase 1: Parallelize spatial on the last block --- + last_block = block_infos[-1] + loops = sch.get_loops(last_block.block_rv) + if num_leading_s > 1: + spatial = sch.fuse(*loops[:num_leading_s]) + else: + spatial = loops[0] + sch.parallel(spatial) + + # --- Phase 2: Vectorize the last (injective) block --- + inner_loops = sch.get_loops(last_block.block_rv) + if len(inner_loops) > 1: + inner = inner_loops[-1] + extent = get_extent(sch, inner) + if isinstance(extent, int) and extent > vec_lanes: + _, vec_loop = sch.split(inner, factors=[None, vec_lanes]) + sch.vectorize(vec_loop) + elif isinstance(extent, int): + sch.vectorize(inner) + + # --- Phase 3: compute_at all preceding blocks under spatial --- + for block_info in reversed(block_infos[:-1]): + sch.compute_at(block_info.block_rv, spatial, preserve_unit_loops=True) + + # --- Phase 4: Vectorize injective, split+unroll reduction blocks --- + for block_info in block_infos[:-1]: + block = block_info.block_rv + block_loops = sch.get_loops(block) + if len(block_loops) <= 1: + continue + inner = block_loops[-1] + extent = get_extent(sch, inner) + + if block_info.is_injective(): + # Injective 
blocks (e.g. exp, delta): vectorize directly. + if isinstance(extent, int) and extent > vec_lanes: + _, vec_loop = sch.split(inner, factors=[None, vec_lanes]) + sch.vectorize(vec_loop) + elif isinstance(extent, int) and extent >= 2: + sch.vectorize(inner) + else: + # Reduction blocks (e.g. max, sum): split inner to vec_lanes + # and annotate for unrolling. This prevents LLVM from doing + # harmful full-unroll of the 185-element loop and gives it + # a vec_lanes-sized inner loop to auto-vectorize. + if isinstance(extent, int) and extent > vec_lanes: + _, inner_loop = sch.split(inner, factors=[None, vec_lanes]) + sch.annotate( + inner_loop, + ann_key="pragma_auto_unroll_max_step", + ann_val=vec_lanes, + ) + sch.annotate( + inner_loop, + ann_key="pragma_unroll_explicit", + ann_val=1, + ) + + return sch diff --git a/tests/python/s_tir/dlight/test_cpu_reduction.py b/tests/python/s_tir/dlight/test_cpu_reduction.py new file mode 100644 index 000000000000..db8280a61a0f --- /dev/null +++ b/tests/python/s_tir/dlight/test_cpu_reduction.py @@ -0,0 +1,270 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring +"""Tests for CPU DLight Reduction schedule rule.""" + +import pytest + +import tvm +import tvm.testing +from tvm import te, tirx, topi +from tvm.s_tir import dlight as dl +from tvm.s_tir.dlight.cpu import Reduction +from tvm.target import Target + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _llvm_target(): + return Target({"kind": "llvm"}) + + +def _rvv_target(): + return Target( + { + "kind": "llvm", + "mtriple": "riscv64-linux-gnu", + "mcpu": "generic-rv64", + "mabi": "lp64d", + "mattr": ["+64bit", "+m", "+a", "+f", "+d", "+c", "+v"], + } + ) + + +def _build_softmax(batch, features, fast=False): + A = te.placeholder((batch, features), dtype="float32", name="A") + B = topi.nn.fast_softmax(A, axis=1) if fast else topi.nn.softmax(A, axis=1) + func = te.create_prim_func([A, B]) + return tvm.IRModule({"main": func}) + + +def _apply_and_check(mod, target): + """Apply Reduction rule and verify it was applied.""" + rule = Reduction() + result = rule.apply(mod["main"], target, False) + assert result is not None, "Reduction rule should apply" + return result + + +# --------------------------------------------------------------------------- +# Test: schedule applicability +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("fast", [False, True], ids=["softmax", "fast_softmax"]) +@pytest.mark.parametrize( + "batch,features", + [ + (1, 10), + (1, 128), + (14, 185), + (32, 256), + (64, 512), + (128, 1024), + (1, 30522), + ], +) +def test_reduction_applies(batch, features, fast): + """Reduction rule should apply to softmax/fast_softmax of various shapes.""" + mod = _build_softmax(batch, features, fast=fast) + target = _llvm_target() + _apply_and_check(mod, target) + + +# --------------------------------------------------------------------------- +# 
Test: scheduled TIR structure +# --------------------------------------------------------------------------- + + +def test_softmax_schedule_structure(): + """Verify the scheduled TIR has expected structure: + - parallel on batch axis + - vectorized innermost loops for injective blocks + - split+unroll for reduction blocks + """ + mod = _build_softmax(14, 185, fast=False) + target = _llvm_target() + sch = _apply_and_check(mod, target) + scheduled_mod = sch.mod + + # Check that tirx.is_scheduled is NOT set (only set by ApplyDefaultSchedule) + # but the schedule should be valid + assert scheduled_mod is not None + + # Verify via ApplyDefaultSchedule path + with target: + scheduled = dl.ApplyDefaultSchedule(Reduction())(mod) + func = scheduled["main"] + + # Check tirx.is_scheduled is set + assert func.attrs and func.attrs.get("tirx.is_scheduled", False) + + +def test_fast_softmax_schedule_structure(): + """fast_softmax should keep T_fast_exp as a separate vectorizable block.""" + mod = _build_softmax(14, 185, fast=True) + target = _llvm_target() + sch = _apply_and_check(mod, target) + script = str(sch.mod) + + # fast_exp block should exist (not inlined) + assert "T_fast_exp" in script or "T_softmax_delta" in script + # Should have T.parallel + assert "T.parallel" in script + # Should have T.vectorized + assert "T.vectorized" in script + + +# --------------------------------------------------------------------------- +# Test: LLVM IR quality (cross-compile to RISC-V RVV) +# --------------------------------------------------------------------------- + + +def _codegen_llvm_ir(mod, target): + """Lower and codegen to LLVM IR (no linking).""" + bound = tirx.transform.BindTarget(target.with_host(target))(mod) + pipeline = tirx.get_tir_pipeline("default") + lowered = pipeline(bound) + from tvm.tirx.build import split_host_device_mods + + host_mod, _ = split_host_device_mods(lowered) + host_mod = tirx.pipeline.finalize_host_passes()(host_mod) + built = 
tvm.target.codegen.build_module(host_mod, target) + return built.inspect_source("ll") + + +def _codegen_asm(mod, target): + """Lower and codegen to assembly (no linking).""" + bound = tirx.transform.BindTarget(target.with_host(target))(mod) + pipeline = tirx.get_tir_pipeline("default") + lowered = pipeline(bound) + from tvm.tirx.build import split_host_device_mods + + host_mod, _ = split_host_device_mods(lowered) + host_mod = tirx.pipeline.finalize_host_passes()(host_mod) + built = tvm.target.codegen.build_module(host_mod, target) + return built.inspect_source("s") + + +@pytest.mark.parametrize("fast", [False, True], ids=["softmax", "fast_softmax"]) +def test_rvv_code_size_reduction(fast): + """Scheduled RVV code should be smaller than unscheduled. + + The original issue (apache/tvm#18569) shows RVV softmax is 1.34x slower + than scalar, partly due to LLVM generating bloated code with excessive + unrolling. The schedule should reduce code size significantly. + """ + target = _rvv_target() + mod = _build_softmax(14, 185, fast=fast) + + # Unscheduled + ir_unsched = _codegen_llvm_ir(mod, target) + n_unsched = len(ir_unsched.splitlines()) + + # Scheduled + with target: + mod_sched = dl.ApplyDefaultSchedule(Reduction())(mod) + ir_sched = _codegen_llvm_ir(mod_sched, target) + n_sched = len(ir_sched.splitlines()) + + # Scheduled should be meaningfully smaller (at least 25% reduction) + ratio = n_sched / n_unsched + assert ratio < 0.75, ( + f"Expected >=25% code reduction, got {(1 - ratio) * 100:.1f}% " + f"({n_unsched} -> {n_sched} lines)" + ) + + +def test_rvv_fast_softmax_vectorizes_exp(): + """fast_softmax + schedule should produce RVV vector instructions + for the polynomial exp approximation (no scalar exp calls).""" + target = _rvv_target() + mod = _build_softmax(14, 185, fast=True) + with target: + mod_sched = dl.ApplyDefaultSchedule(Reduction())(mod) + ir = _codegen_llvm_ir(mod_sched, target) + + # Should have zero scalar exp calls (fast_exp uses polynomial) + 
scalar_exp = sum(1 for line in ir.splitlines() if "llvm.exp.f32" in line) + assert scalar_exp == 0, f"Expected 0 scalar exp calls, got {scalar_exp}" + + # Should have scalable vector operations + n_svec = ir.count("<vscale x") + assert n_svec > 0, "Expected scalable vector operations in LLVM IR" + + +def test_rvv_asm_instruction_reduction(): + """Scheduled RVV assembly should have fewer total instructions + than both unscheduled RVV and scalar RV.""" + rvv = _rvv_target() + rv = Target( + { + "kind": "llvm", + "mtriple": "riscv64-linux-gnu", + "mcpu": "generic-rv64", + "mabi": "lp64d", + "mattr": ["+64bit", "+m", "+a", "+f", "+d", "+c"], + } + ) + + mod = _build_softmax(14, 185, fast=True) + + # Scalar baseline + asm_rv = _codegen_asm(mod, rv) + n_rv = len( + [ + line + for line in asm_rv.splitlines() + if line.strip() and not line.strip().startswith((".", "#", "/")) + ] + ) + + # RVV unscheduled + asm_rvv = _codegen_asm(mod, rvv) + n_rvv = len( + [ + line + for line in asm_rvv.splitlines() + if line.strip() and not line.strip().startswith((".", "#", "/")) + ] + ) + + # RVV scheduled + with rvv: + mod_sched = dl.ApplyDefaultSchedule(Reduction())(mod) + asm_sched = _codegen_asm(mod_sched, rvv) + n_sched = len( + [ + line + for line in asm_sched.splitlines() + if line.strip() and not line.strip().startswith((".", "#", "/")) + ] + ) + + # Scheduled should be smaller than both unscheduled RVV and scalar + assert n_sched < n_rvv, ( + f"Scheduled ({n_sched}) should have fewer instructions than unscheduled RVV ({n_rvv})" + ) + assert n_sched <= n_rv * 1.1, ( + f"Scheduled ({n_sched}) should not be much larger than scalar RV ({n_rv})" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 4598374f327fcdf31a444bc32846479051864bb8 Mon Sep 17 00:00:00 2001 From: Soowon Jeong Date: Fri, 10 Apr 2026 15:52:42 +0900 Subject: [PATCH 2/2] [DLight] Infer dtype from buffer, support dynamic shapes, validate spatial axes - Infer dtype_bits from the last block's write buffer instead of hardcoding 
32, so float16/bfloat16 get correct vector lane counts. - Support dynamic extents in split+vectorize by removing isinstance(extent, int) guards where TVM primitives handle them. - Replace broad except Exception around llvm_get_vector_width with return-value check (returns -1 on failure, not an exception). - Compute num_leading_s as min across ALL blocks, not just the first reduction block, ensuring compute_at safety. - Extract _vectorize_inner and _unroll_reduction_inner as static methods to reduce apply() complexity. --- python/tvm/s_tir/dlight/cpu/reduction.py | 114 ++++++++++++----------- 1 file changed, 60 insertions(+), 54 deletions(-) diff --git a/python/tvm/s_tir/dlight/cpu/reduction.py b/python/tvm/s_tir/dlight/cpu/reduction.py index 60d660603b3c..2e804f9537c8 100644 --- a/python/tvm/s_tir/dlight/cpu/reduction.py +++ b/python/tvm/s_tir/dlight/cpu/reduction.py @@ -16,7 +16,7 @@ # under the License. """CPU reduction rule for operators including softmax, layer norm, RMS norm, etc.""" -from tvm import s_tir, tirx +from tvm import DataType, s_tir, tirx from tvm.target import Target from tvm.target.codegen import llvm_get_vector_width @@ -25,6 +25,11 @@ from .base import CPUScheduleRule +def _get_num_leading_s(dom_kind: str) -> int: + """Count leading spatial ('S') axes in a dom_kind string.""" + return len(dom_kind) - len(dom_kind.lstrip("S")) + + class Reduction(CPUScheduleRule): """CPU reduction rule for softmax, layer norm, RMS norm, and similar operators. @@ -61,30 +66,34 @@ def apply( # pylint: disable=too-many-locals,too-many-return-statements,too-man return None # Must have at least one reduction block and last block must be injective. - has_reduction = any(not bi.is_injective() for bi in block_infos) - if not has_reduction or not block_infos[-1].is_injective(): + if not any(not bi.is_injective() for bi in block_infos): + return None + if not block_infos[-1].is_injective(): return None - # All blocks must have at least one leading spatial axis. 
+ # Every block must start with at least one spatial axis, and all blocks + # must agree on the minimum number of leading spatial axes. + num_leading_s = None for bi in block_infos: dk = bi.dom_kind() if not dk or dk[0] != "S": return None - - # Find the number of leading spatial axes (from the first reduction block). - first_reduction = next(bi for bi in block_infos if not bi.is_injective()) - dom_kind = first_reduction.dom_kind() - num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S")) - if num_leading_s == 0: + n = _get_num_leading_s(dk) + num_leading_s = n if num_leading_s is None else min(num_leading_s, n) + if not num_leading_s: return None - # Determine vector width from target. - try: - vlen_bits = llvm_get_vector_width(target) - except Exception: # pylint: disable=broad-except + # Infer dtype from the last block's write buffer. + last_block_stmt = sch.get(block_infos[-1].block_rv) + dtype_bits = ( + DataType(last_block_stmt.writes[0].buffer.dtype).bits if last_block_stmt.writes else 32 + ) + + # Determine vector lanes from target VLEN. 
+ vlen_bits = llvm_get_vector_width(target) + if vlen_bits <= 0: vlen_bits = 128 - dtype_bits = 32 # default float32 - vec_lanes = max(vlen_bits // dtype_bits, 4) + vec_lanes = max(vlen_bits // dtype_bits, 2) # --- Phase 1: Parallelize spatial on the last block --- last_block = block_infos[-1] @@ -96,15 +105,7 @@ def apply( # pylint: disable=too-many-locals,too-many-return-statements,too-man sch.parallel(spatial) # --- Phase 2: Vectorize the last (injective) block --- - inner_loops = sch.get_loops(last_block.block_rv) - if len(inner_loops) > 1: - inner = inner_loops[-1] - extent = get_extent(sch, inner) - if isinstance(extent, int) and extent > vec_lanes: - _, vec_loop = sch.split(inner, factors=[None, vec_lanes]) - sch.vectorize(vec_loop) - elif isinstance(extent, int): - sch.vectorize(inner) + self._vectorize_inner(sch, last_block.block_rv, vec_lanes) # --- Phase 3: compute_at all preceding blocks under spatial --- for block_info in reversed(block_infos[:-1]): @@ -112,36 +113,41 @@ def apply( # pylint: disable=too-many-locals,too-many-return-statements,too-man # --- Phase 4: Vectorize injective, split+unroll reduction blocks --- for block_info in block_infos[:-1]: - block = block_info.block_rv - block_loops = sch.get_loops(block) - if len(block_loops) <= 1: - continue - inner = block_loops[-1] - extent = get_extent(sch, inner) - if block_info.is_injective(): - # Injective blocks (e.g. exp, delta): vectorize directly. - if isinstance(extent, int) and extent > vec_lanes: - _, vec_loop = sch.split(inner, factors=[None, vec_lanes]) - sch.vectorize(vec_loop) - elif isinstance(extent, int) and extent >= 2: - sch.vectorize(inner) + self._vectorize_inner(sch, block_info.block_rv, vec_lanes) else: - # Reduction blocks (e.g. max, sum): split inner to vec_lanes - # and annotate for unrolling. This prevents LLVM from doing - # harmful full-unroll of the 185-element loop and gives it - # a vec_lanes-sized inner loop to auto-vectorize. 
- if isinstance(extent, int) and extent > vec_lanes: - _, inner_loop = sch.split(inner, factors=[None, vec_lanes]) - sch.annotate( - inner_loop, - ann_key="pragma_auto_unroll_max_step", - ann_val=vec_lanes, - ) - sch.annotate( - inner_loop, - ann_key="pragma_unroll_explicit", - ann_val=1, - ) + self._unroll_reduction_inner(sch, block_info.block_rv, vec_lanes) return sch + + @staticmethod + def _vectorize_inner(sch, block_rv, vec_lanes): + """Split the innermost loop to vec_lanes and vectorize.""" + block_loops = sch.get_loops(block_rv) + if len(block_loops) <= 1: + return + inner = block_loops[-1] + extent = get_extent(sch, inner) + if isinstance(extent, int): + if extent > vec_lanes: + _, vec_loop = sch.split(inner, factors=[None, vec_lanes]) + sch.vectorize(vec_loop) + elif extent >= 2: + sch.vectorize(inner) + else: + _, vec_loop = sch.split(inner, factors=[None, vec_lanes]) + sch.vectorize(vec_loop) + + @staticmethod + def _unroll_reduction_inner(sch, block_rv, vec_lanes): + """Split the reduction inner loop and annotate for unrolling.""" + block_loops = sch.get_loops(block_rv) + if len(block_loops) <= 1: + return + inner = block_loops[-1] + extent = get_extent(sch, inner) + if isinstance(extent, int) and extent <= vec_lanes: + return + _, inner_loop = sch.split(inner, factors=[None, vec_lanes]) + sch.annotate(inner_loop, ann_key="pragma_auto_unroll_max_step", ann_val=vec_lanes) + sch.annotate(inner_loop, ann_key="pragma_unroll_explicit", ann_val=1)