diff --git a/python/tvm/s_tir/dlight/cpu/__init__.py b/python/tvm/s_tir/dlight/cpu/__init__.py index 8743c616bb10..20e1e9a3b829 100644 --- a/python/tvm/s_tir/dlight/cpu/__init__.py +++ b/python/tvm/s_tir/dlight/cpu/__init__.py @@ -20,3 +20,4 @@ """ from .gemv import GEMV +from .reduction import Reduction diff --git a/python/tvm/s_tir/dlight/cpu/reduction.py b/python/tvm/s_tir/dlight/cpu/reduction.py new file mode 100644 index 000000000000..2e804f9537c8 --- /dev/null +++ b/python/tvm/s_tir/dlight/cpu/reduction.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""CPU reduction rule for operators including softmax, layer norm, RMS norm, etc.""" + +from tvm import DataType, s_tir, tirx +from tvm.target import Target +from tvm.target.codegen import llvm_get_vector_width + +from ..analysis import normalize_prim_func +from ..base import get_extent +from .base import CPUScheduleRule + + +def _get_num_leading_s(dom_kind: str) -> int: + """Count leading spatial ('S') axes in a dom_kind string.""" + return len(dom_kind) - len(dom_kind.lstrip("S")) + + +class Reduction(CPUScheduleRule): + """CPU reduction rule for softmax, layer norm, RMS norm, and similar operators. + + Targets patterns with a mix of reduction (SR) and injective (SS) blocks, + where all blocks share the same leading spatial axes. + Example: softmax = maxelem(SR) -> exp(SS) -> expsum(SR) -> norm(SS). + + Schedule strategy: + 1. Parallelize leading spatial axes (batch dimension). + 2. Move all blocks under the spatial loop via compute_at. + 3. Vectorize injective blocks (exp, delta, norm) on their inner axis. + 4. Split reduction inner axis to VLEN-sized chunks and annotate for + LLVM unrolling, preventing harmful full-unroll by the backend. + + Note: vectorized reduction via rfactor is not used here because TVM's + rfactor primitive requires the reduction block to be the first child of + its enclosing loop, which is incompatible with compute_at when multiple + blocks share the same spatial loop. A follow-up using RVV reduction + intrinsics (vfredmax/vfredusum) via tensorize can address this. + """ + + def apply( # pylint: disable=too-many-locals,too-many-return-statements,too-many-branches + self, + func: tirx.PrimFunc, + target: Target, + _: bool, + ) -> None | s_tir.Schedule | list[s_tir.Schedule]: + if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target): + return None + + sch = s_tir.Schedule(func) + block_infos = normalize_prim_func(sch) + if block_infos is None or len(block_infos) < 2: + return None + + # Must have at least one reduction block and last block must be injective. + if not any(not bi.is_injective() for bi in block_infos): + return None + if not block_infos[-1].is_injective(): + return None + + # Every block must start with at least one spatial axis, and all blocks + # must agree on the minimum number of leading spatial axes. + num_leading_s = None + for bi in block_infos: + dk = bi.dom_kind() + if not dk or dk[0] != "S": + return None + n = _get_num_leading_s(dk) + num_leading_s = n if num_leading_s is None else min(num_leading_s, n) + if not num_leading_s: + return None + + # Infer dtype from the last block's write buffer. + last_block_stmt = sch.get(block_infos[-1].block_rv) + dtype_bits = ( + DataType(last_block_stmt.writes[0].buffer.dtype).bits if last_block_stmt.writes else 32 + ) + + # Determine vector lanes from target VLEN. + vlen_bits = llvm_get_vector_width(target) + if vlen_bits <= 0: + vlen_bits = 128 + vec_lanes = max(vlen_bits // dtype_bits, 2) + + # --- Phase 1: Parallelize spatial on the last block --- + last_block = block_infos[-1] + loops = sch.get_loops(last_block.block_rv) + if num_leading_s > 1: + spatial = sch.fuse(*loops[:num_leading_s]) + else: + spatial = loops[0] + sch.parallel(spatial) + + # --- Phase 2: Vectorize the last (injective) block --- + self._vectorize_inner(sch, last_block.block_rv, vec_lanes) + + # --- Phase 3: compute_at all preceding blocks under spatial --- + for block_info in reversed(block_infos[:-1]): + sch.compute_at(block_info.block_rv, spatial, preserve_unit_loops=True) + + # --- Phase 4: Vectorize injective, split+unroll reduction blocks --- + for block_info in block_infos[:-1]: + if block_info.is_injective(): + self._vectorize_inner(sch, block_info.block_rv, vec_lanes) + else: + self._unroll_reduction_inner(sch, block_info.block_rv, vec_lanes) + + return sch + + @staticmethod + def _vectorize_inner(sch, block_rv, vec_lanes): + """Split the innermost loop to vec_lanes and vectorize.""" + block_loops = sch.get_loops(block_rv) + if len(block_loops) <= 1: + return + inner = block_loops[-1] + extent = get_extent(sch, inner) + if isinstance(extent, int): + if extent > vec_lanes: + _, vec_loop = sch.split(inner, factors=[None, vec_lanes]) + sch.vectorize(vec_loop) + elif extent >= 2: + sch.vectorize(inner) + else: + _, vec_loop = sch.split(inner, factors=[None, vec_lanes]) + sch.vectorize(vec_loop) + + @staticmethod + def _unroll_reduction_inner(sch, block_rv, vec_lanes): + """Split the reduction inner loop and annotate for unrolling.""" + block_loops = sch.get_loops(block_rv) + if len(block_loops) <= 1: + return + inner = block_loops[-1] + extent = get_extent(sch, inner) + if isinstance(extent, int) and extent <= vec_lanes: + return + _, inner_loop = sch.split(inner, factors=[None, vec_lanes]) + sch.annotate(inner_loop, ann_key="pragma_auto_unroll_max_step", ann_val=vec_lanes) + sch.annotate(inner_loop, ann_key="pragma_unroll_explicit", ann_val=1) diff --git a/tests/python/s_tir/dlight/test_cpu_reduction.py b/tests/python/s_tir/dlight/test_cpu_reduction.py new file mode 100644 index 000000000000..db8280a61a0f --- /dev/null +++ b/tests/python/s_tir/dlight/test_cpu_reduction.py @@ -0,0 +1,270 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +"""Tests for CPU DLight Reduction schedule rule.""" + +import pytest + +import tvm +import tvm.testing +from tvm import te, tirx, topi +from tvm.s_tir import dlight as dl +from tvm.s_tir.dlight.cpu import Reduction +from tvm.target import Target + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _llvm_target(): + return Target({"kind": "llvm"}) + + +def _rvv_target(): + return Target( + { + "kind": "llvm", + "mtriple": "riscv64-linux-gnu", + "mcpu": "generic-rv64", + "mabi": "lp64d", + "mattr": ["+64bit", "+m", "+a", "+f", "+d", "+c", "+v"], + } + ) + + +def _build_softmax(batch, features, fast=False): + A = te.placeholder((batch, features), dtype="float32", name="A") + B = topi.nn.fast_softmax(A, axis=1) if fast else topi.nn.softmax(A, axis=1) + func = te.create_prim_func([A, B]) + return tvm.IRModule({"main": func}) + + +def _apply_and_check(mod, target): + """Apply Reduction rule and verify it was applied.""" + rule = Reduction() + result = rule.apply(mod["main"], target, False) + assert result is not None, "Reduction rule should apply" + return result + + +# --------------------------------------------------------------------------- +# Test: schedule applicability +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("fast", [False, True], ids=["softmax", "fast_softmax"]) +@pytest.mark.parametrize( + "batch,features", + [ + (1, 10), + (1, 128), + (14, 185), + (32, 256), + (64, 512), + (128, 1024), + (1, 30522), + ], +) +def test_reduction_applies(batch, features, fast): + """Reduction rule should apply to softmax/fast_softmax of various shapes.""" + mod = _build_softmax(batch, features, fast=fast) + target = _llvm_target() + _apply_and_check(mod, target) + + +# --------------------------------------------------------------------------- +# Test: scheduled TIR structure +# --------------------------------------------------------------------------- + + +def test_softmax_schedule_structure(): + """Verify the scheduled TIR has expected structure: + - parallel on batch axis + - vectorized innermost loops for injective blocks + - split+unroll for reduction blocks + """ + mod = _build_softmax(14, 185, fast=False) + target = _llvm_target() + sch = _apply_and_check(mod, target) + scheduled_mod = sch.mod + + # Check that tirx.is_scheduled is NOT set (only set by ApplyDefaultSchedule) + # but the schedule should be valid + assert scheduled_mod is not None + + # Verify via ApplyDefaultSchedule path + with target: + scheduled = dl.ApplyDefaultSchedule(Reduction())(mod) + func = scheduled["main"] + + # Check tirx.is_scheduled is set + assert func.attrs and func.attrs.get("tirx.is_scheduled", False) + + +def test_fast_softmax_schedule_structure(): + """fast_softmax should keep T_fast_exp as a separate vectorizable block.""" + mod = _build_softmax(14, 185, fast=True) + target = _llvm_target() + sch = _apply_and_check(mod, target) + script = str(sch.mod) + + # fast_exp block should exist (not inlined) + assert "T_fast_exp" in script or "T_softmax_delta" in script + # Should have T.parallel + assert "T.parallel" in script + # Should have T.vectorized + assert "T.vectorized" in script + + +# --------------------------------------------------------------------------- +# Test: LLVM IR quality (cross-compile to RISC-V RVV) +# --------------------------------------------------------------------------- + + +def _codegen_llvm_ir(mod, target): + """Lower and codegen to LLVM IR (no linking).""" + bound = tirx.transform.BindTarget(target.with_host(target))(mod) + pipeline = tirx.get_tir_pipeline("default") + lowered = pipeline(bound) + from tvm.tirx.build import split_host_device_mods + + host_mod, _ = split_host_device_mods(lowered) + host_mod = tirx.pipeline.finalize_host_passes()(host_mod) + built = tvm.target.codegen.build_module(host_mod, target) + return built.inspect_source("ll") + + +def _codegen_asm(mod, target): + """Lower and codegen to assembly (no linking).""" + bound = tirx.transform.BindTarget(target.with_host(target))(mod) + pipeline = tirx.get_tir_pipeline("default") + lowered = pipeline(bound) + from tvm.tirx.build import split_host_device_mods + + host_mod, _ = split_host_device_mods(lowered) + host_mod = tirx.pipeline.finalize_host_passes()(host_mod) + built = tvm.target.codegen.build_module(host_mod, target) + return built.inspect_source("s") + + +@pytest.mark.parametrize("fast", [False, True], ids=["softmax", "fast_softmax"]) +def test_rvv_code_size_reduction(fast): + """Scheduled RVV code should be smaller than unscheduled. + + The original issue (apache/tvm#18569) shows RVV softmax is 1.34x slower + than scalar, partly due to LLVM generating bloated code with excessive + unrolling. The schedule should reduce code size significantly. + """ + target = _rvv_target() + mod = _build_softmax(14, 185, fast=fast) + + # Unscheduled + ir_unsched = _codegen_llvm_ir(mod, target) + n_unsched = len(ir_unsched.splitlines()) + + # Scheduled + with target: + mod_sched = dl.ApplyDefaultSchedule(Reduction())(mod) + ir_sched = _codegen_llvm_ir(mod_sched, target) + n_sched = len(ir_sched.splitlines()) + + # Scheduled should be meaningfully smaller (at least 30% reduction) + ratio = n_sched / n_unsched + assert ratio < 0.75, ( + f"Expected >=25% code reduction, got {(1 - ratio) * 100:.1f}% " + f"({n_unsched} -> {n_sched} lines)" + ) + + +def test_rvv_fast_softmax_vectorizes_exp(): + """fast_softmax + schedule should produce RVV vector instructions + for the polynomial exp approximation (no scalar exp calls).""" + target = _rvv_target() + mod = _build_softmax(14, 185, fast=True) + with target: + mod_sched = dl.ApplyDefaultSchedule(Reduction())(mod) + ir = _codegen_llvm_ir(mod_sched, target) + + # Should have zero scalar exp calls (fast_exp uses polynomial) + scalar_exp = sum(1 for line in ir.splitlines() if "llvm.exp.f32" in line) + assert scalar_exp == 0, f"Expected 0 scalar exp calls, got {scalar_exp}" + + # Should have scalable vector operations + n_svec = ir.count(" 0, "Expected scalable vector operations in LLVM IR" + + +def test_rvv_asm_instruction_reduction(): + """Scheduled RVV assembly should have fewer total instructions + than both unscheduled RVV and scalar RV.""" + rvv = _rvv_target() + rv = Target( + { + "kind": "llvm", + "mtriple": "riscv64-linux-gnu", + "mcpu": "generic-rv64", + "mabi": "lp64d", + "mattr": ["+64bit", "+m", "+a", "+f", "+d", "+c"], + } + ) + + mod = _build_softmax(14, 185, fast=True) + + # Scalar baseline + asm_rv = _codegen_asm(mod, rv) + n_rv = len( + [ + line + for line in asm_rv.splitlines() + if line.strip() and not line.strip().startswith((".", "#", "/")) + ] + ) + + # RVV unscheduled + asm_rvv = _codegen_asm(mod, rvv) + n_rvv = len( + [ + line + for line in asm_rvv.splitlines() + if line.strip() and not line.strip().startswith((".", "#", "/")) + ] + ) + + # RVV scheduled + with rvv: + mod_sched = dl.ApplyDefaultSchedule(Reduction())(mod) + asm_sched = _codegen_asm(mod_sched, rvv) + n_sched = len( + [ + line + for line in asm_sched.splitlines() + if line.strip() and not line.strip().startswith((".", "#", "/")) + ] + ) + + # Scheduled should be smaller than both unscheduled RVV and scalar + assert n_sched < n_rvv, ( + f"Scheduled ({n_sched}) should have fewer instructions than unscheduled RVV ({n_rvv})" + ) + assert n_sched <= n_rv * 1.1, ( + f"Scheduled ({n_sched}) should not be much larger than scalar RV ({n_rv})" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])