
Add TMA scheduler for outer2D reduction#5966

Open
tbqh wants to merge 4 commits into main from tbqh/auto_outer_reduce_tma

Conversation

@tbqh tbqh commented Feb 17, 2026

Add an auto-scheduler for TMA outer reduction, using a schedule similar to PR #5926: fixed block dims, TMA tile sizes, and unroll factor. It always targets grid reduction with a very simple heuristic.

[screenshot: 2026-02-17_05-35]

@tbqh tbqh requested a review from liqiangxl February 17, 2026 11:54

github-actions bot commented Feb 17, 2026

Review updated until commit 881b1e1

Description

  • Add new TMA scheduler for outer reduction operations using 2D thread blocks

  • Implement outer TMA scheduling with fixed parameters (bdimx=32, bdimy=16, TMA tiles=128x128)

  • Add auto-scheduler integration that tries outer TMA before inner TMA and non-TMA fallbacks

  • Include comprehensive tests for outer TMA reduction with various tensor sizes

Changes walkthrough

Relevant files

Enhancement

csrc/scheduler/reduction.cpp (+63/-5): Integrate outer TMA scheduler into main reduction scheduler

  • Added include for the reduction_outer_tma.h header
  • Implemented mayUseTmaOuter function with TMA eligibility checks for outer reductions
  • Modified computeHeuristics to prioritize the outer TMA scheduler over inner TMA
  • Updated the schedule function to handle the new TmaOuterReductionParams type
csrc/scheduler/reduction_outer_tma.cpp (+231/-0): Implement outer TMA reduction scheduler with 2D tiling

  • Implemented getReductionHeuristics with fixed TMA parameters (bdimx=32, bdimy=16, tiles=128x128)
  • Implemented scheduleReduction with a 10-phase TMA scheduling algorithm (the tiling step is sketched below)
  • Uses 2D TMA tiling, grid parallelization, and rFactor for outer reductions
  • Propagates transformations and parallelization across all tensors
csrc/scheduler/reduction_outer_tma.h (+52/-0): Define outer TMA reduction parameters and function declarations

  • Defined TmaOuterReductionParams class with TMA-specific parameters (a rough sketch follows this entry)
  • Declared getReductionHeuristics and scheduleReduction functions
  • Specified thread block dimensions, TMA tiles, and unroll factors
Tests

tests/cpp/test_reduction.cpp (+107/-6): Add comprehensive tests for outer TMA reduction scheduler

  • Added outer TMA header include and updated TMA parameter detection
  • Modified the manual test to only check iteration-dimension alignment
  • Added TmaOuterReductionTest class for auto-scheduled outer TMA testing
  • Tests various tensor sizes with automatic TMA scheduler selection

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Heuristics Need Tuning

The TODO comment on lines 30-31 explicitly states that the heuristics are "stubbed out based on manual test" and need proper tuning. The scheduler uses fixed parameters (bdimx=32, bdimy=16, tma_tile_i=128, tma_tile_r=128) without adaptive selection based on problem size, memory bandwidth, or other performance factors. This could lead to suboptimal performance across different problem sizes.

    // TODO: These heuristics are stubbed out based on the manual test
    // (TmaOuterReductionManualTest::Basic). They need proper tuning.
Redundant Alignment Check

The mayUseTmaOuter function duplicates the alignment-check logic of mayUseTma (lines 282-284 vs. lines 241-244): both check whether vectorize_factor * max_dtype_size_bit % 128 != 0. This duplication could lead to maintenance issues and inconsistent behavior if one copy is updated without the other.

    // TMA requires 16-byte alignment (128 bits) for memory transactions
    if (props.vectorize_factor * props.max_dtype_size_bit % 128 != 0) {
      return false;
    }
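
One way to remove the duplication would be a shared helper; this is a hypothetical sketch (the helper name is invented, and the field name uses the corrected max_dtype_size_bit_for_vectorization flagged in the inline comments below):

    // Hypothetical helper to keep mayUseTma and mayUseTmaOuter in sync.
    // TMA requires 16-byte (128-bit) aligned memory transactions.
    bool hasTmaAlignment(
        const reduction_scheduler_utils::FusionRuntimeProperties& props) {
      return props.vectorize_factor *
              props.max_dtype_size_bit_for_vectorization % 128 ==
          0;
    }
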
Hard-coded Grid Dimension Calculation

The grid dimension calculation (lines 50-51) uses a hard-coded formula that clamps lastPow2(outer_size / 256) to [1, 8]. The arbitrary choice of 256 and 8 as limits may not be optimal across different GPU architectures, problem sizes, or memory configurations; the calculation should be more robust and potentially configurable.

    int64_t grdim = std::max<int64_t>(
        1, std::min<int64_t>(8, scheduler_utils::lastPow2(outer_size / 256)));
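
To make the clamp concrete, here is a self-contained illustration; lastPow2 below is a stand-in assumed to match scheduler_utils::lastPow2 (largest power of two not exceeding its argument):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Stand-in for scheduler_utils::lastPow2: largest power of two <= n.
    int64_t lastPow2(int64_t n) {
      int64_t p = 1;
      while (p * 2 <= n) {
        p *= 2;
      }
      return n > 0 ? p : 0;
    }

    int main() {
      // grdim = clamp(lastPow2(outer_size / 256), 1, 8)
      for (int64_t outer_size : {100, 512, 1024, 65536}) {
        int64_t grdim = std::max<int64_t>(
            1, std::min<int64_t>(8, lastPow2(outer_size / 256)));
        std::printf("outer_size=%6lld -> grdim=%lld\n",
                    (long long)outer_size, (long long)grdim);
      }
      // Prints 1, 2, 4, 8: small reductions get a single CTA, large
      // ones cap at 8 cooperating CTAs.
      return 0;
    }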

greptile-apps bot commented Feb 17, 2026

Greptile Summary

This PR introduces a new TMA-based auto-scheduler for outer (non-fastest-dim) reductions, complementing the existing inner TMA scheduler from PR #5926. The scheduler uses a fixed 2D thread block (32×16), fixed 128×128 TMA tiles, and a simple heuristic for grid-level parallelism (clamped power-of-two of outer_size / 256). The dispatch logic in reduction.cpp now tries the outer TMA path first, then the inner TMA path, then falls back to the non-TMA scheduler.

Key changes:

• New TmaOuterReductionParams class and outer_tma::getReductionHeuristics / scheduleReduction functions in reduction_outer_tma.{h,cpp}.
• mayUseTmaOuter guard in reduction.cpp that restricts eligibility to SM≥9, outer reductions, and single-input fusions with at least 16 KB of reduction data and a 16-byte-aligned innermost stride (sketched after this list).
• New TmaOuterReductionTest parameterized test suite covering a range of {outer_size, iter_size} combinations.
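
The following reconstructs that guard from the conditions above; only the alignment check is quoted verbatim in this review, so the other property names and control flow are assumptions:

    // Reconstructed sketch of mayUseTmaOuter (not the PR's exact code).
    bool mayUseTmaOuter(
        Fusion* fusion,
        const reduction_scheduler_utils::FusionRuntimeProperties& props) {
      // TMA needs Hopper or newer (SM >= 9).
      if (at::cuda::getCurrentDeviceProperties()->major < 9) {
        return false;
      }
      // Outer (non-fastest-dim) reductions only; the inner TMA path
      // handles fastest-dim reductions. (Field name is an assumption.)
      if (props.fastest_dim_reduction) {
        return false;
      }
      // Single-input fusion with >= 16 KB of reduction data, per the
      // summary above; those checks are elided here.
      // ...
      // TMA requires 16-byte alignment (128 bits) for memory transactions.
      if (props.vectorize_factor *
              props.max_dtype_size_bit_for_vectorization % 128 !=
          0) {
        return false;
      }
      return true;
    }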

Critical issue: Both mayUseTma (line 242) and mayUseTmaOuter (line 282) in reduction.cpp reference props.max_dtype_size_bit, which does not exist in reduction_scheduler_utils::FusionRuntimeProperties. The correct field is max_dtype_size_bit_for_vectorization. The earlier "Fix compile error" commit (b4e784c) only corrected the variable name from prop to props but left the wrong field name, meaning the code still does not compile.

Minor issue: redu_unroll_factor is computed in getReductionHeuristics and stored in the params struct, but scheduleReduction never reads it back — it re-derives the equivalent value implicitly via redu_tv->split(2, bdimy).

Confidence Score: 2/5

• Not safe to merge: the PR contains a compile error (props.max_dtype_size_bit does not exist in FusionRuntimeProperties) that will prevent the build from succeeding.
• The compile error on lines 242 and 282 of reduction.cpp — using props.max_dtype_size_bit instead of props.max_dtype_size_bit_for_vectorization — is a blocker. An earlier commit attempted to fix a related typo (prop vs. props) but missed the incorrect field name. All other parts of the scheduler logic (2D tiling, rFactor, parallelization propagation) follow established patterns from the inner TMA scheduler and look structurally sound.
• csrc/scheduler/reduction.cpp lines 242 and 282 require fixing before this can build.

Important Files Changed

csrc/scheduler/reduction.cpp: Adds outer TMA scheduler dispatch. Contains a compile error: props.max_dtype_size_bit (used in both mayUseTma and mayUseTmaOuter) is not a field of reduction_scheduler_utils::FusionRuntimeProperties, which only has max_dtype_size_bit_for_vectorization.

csrc/scheduler/reduction_outer_tma.cpp: New outer TMA scheduler implementation. The multi-phase scheduling logic (TMA tiling, grid reduction, rFactor, parallelization) is consistent with the inner TMA scheduler. Minor: redu_unroll_factor is computed and stored in params but never read back in scheduleReduction, which re-derives the split from bdimy directly.

csrc/scheduler/reduction_outer_tma.h: New params class and function declarations. Clean and consistent with the existing TmaInnerReductionParams pattern. No issues found.

tests/cpp/test_reduction.cpp: Adds the TmaOuterReductionTest parameterized suite and updates the isTmaParams helper. The expectOuterTmaUsed helper closely mirrors the mayUseTmaOuter conditions. The test only asserts scheduler selection for large outer sizes (>= 4096 elements for float32) and skips rather than asserting for small sizes.

Flowchart

    flowchart TD
        A[ReductionScheduler::computeHeuristics] --> B{tma_enabled?}
        B -- No --> F[non_tma::getReductionHeuristics]
        B -- Yes --> C{mayUseTmaOuter?}
        C -- "No: inner reduction / small size / misaligned / multi-input" --> D{mayUseTma?}
        C -- "Yes: outer reduction, large enough, SM>=9, single input, aligned" --> E[outer_tma::getReductionHeuristics]
        D -- No --> F
        D -- Yes --> G[tma::getReductionHeuristics]
        E --> H[TmaOuterReductionParams]
        G --> I[TmaInnerReductionParams]
        F --> J[ReductionParams]
    
        H --> K[ReductionScheduler::schedule]
        I --> K
        J --> K
    
        K --> L{dynamic_cast}
        L -- TmaOuterReductionParams --> M[outer_tma::scheduleReduction]
        L -- TmaInnerReductionParams --> N[tma::scheduleReduction]
        L -- ReductionParams --> O[non_tma::scheduleReduction]
    
        subgraph "outer_tma::scheduleReduction"
            M --> P1[Phase 1: cacheInputs - set CpAsyncBulkTensorTile]
            P1 --> P2[Phase 2: TMA 2D tiling on smem TV - split R,I by tma_tile_r,tma_tile_i]
            P2 --> P3[Phase 3: Propagate TMA transforms to all TVs]
            P3 --> P4[Phase 4: Parallelize TMA TV - BIDy/Serial/Bulk/BIDx/Bulk]
            P4 --> P5[Phase 5: Sub-split TMA tiles into TIDx,TIDy]
            P5 --> P6[Phase 6: Parallelize reduction TV]
            P6 --> P7[Phase 7: rFactor non-thread reduction axes for grid reduction]
            P7 --> P8[Phase 8: Propagate to non-TMA TVs]
            P8 --> P9[Phase 9: propagateParallelization with iter-grouped reduction]
            P9 --> P10[Phase 10: inlineMost + refineCachePolicy]
        end
    

Last reviewed commit: 881b1e1

greptile-apps bot reviewed 4 files and left the following inline comments.

    }

    // TMA requires 16-byte alignment (128 bits) for memory transactions
    if (props.vectorize_factor * props.max_dtype_size_bit % 128 != 0) {
max_dtype_size_bit field does not exist — compile error

reduction_scheduler_utils::FusionRuntimeProperties has no field named max_dtype_size_bit; the correct field is max_dtype_size_bit_for_vectorization. This causes a compilation failure. The intent of the check (16-byte TMA alignment) matches max_dtype_size_bit_for_vectorization.

Suggested change:

    - if (props.vectorize_factor * props.max_dtype_size_bit % 128 != 0) {
    + if (props.vectorize_factor * props.max_dtype_size_bit_for_vectorization % 128 != 0) {

    }

    // TMA requires 16-byte alignment (128 bits) for memory transactions
    if (props.vectorize_factor * props.max_dtype_size_bit % 128 != 0) {
max_dtype_size_bit field does not exist — compile error

Same issue as in mayUseTma above: reduction_scheduler_utils::FusionRuntimeProperties has no max_dtype_size_bit field; it should be max_dtype_size_bit_for_vectorization.

Suggested change:

    - if (props.vectorize_factor * props.max_dtype_size_bit % 128 != 0) {
    + if (props.vectorize_factor * props.max_dtype_size_bit_for_vectorization % 128 != 0) {

Comment on lines +45 to +60
    const int64_t redu_unroll_factor = tma_tile_r / bdimy;

    // Grid dimension for parallelizing the outer reduction across CTAs.
    // Modeled after the manual test: clamp lastPow2(outer_size / 256) to [1, 8].
    const int64_t outer_size = props.total_reduction_numel;
    int64_t grdim = std::max<int64_t>(
        1, std::min<int64_t>(8, scheduler_utils::lastPow2(outer_size / 256)));

    auto params = std::make_unique<TmaOuterReductionParams>();
    params->bdimx = bdimx;
    params->bdimy = bdimy;
    params->tma_tile_i = tma_tile_i;
    params->tma_tile_r = tma_tile_r;
    params->iter_unroll_factor = iter_unroll_factor;
    params->redu_unroll_factor = redu_unroll_factor;
    params->grdim = grdim;
redu_unroll_factor computed but never used in scheduleReduction

redu_unroll_factor is computed in getReductionHeuristics and stored in params->redu_unroll_factor, but scheduleReduction never reads rparams->redu_unroll_factor. Instead, it re-derives the same split factor implicitly via redu_tv->split(2, bdimy), whose outer extent, tma_tile_r / bdimy, is the unroll count. By contrast, iter_unroll_factor is explicitly re-derived and used (redu_tv->split(4, iter_unroll_factor)). Consider either removing redu_unroll_factor from the params struct and the heuristics computation, or making scheduleReduction read rparams->redu_unroll_factor explicitly for consistency (a sketch of the latter follows).
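
A minimal sketch of the second option, assuming tma_tile_r == redu_unroll_factor * bdimy as computed in getReductionHeuristics; this is one hypothetical shape for the fix, not the PR's code:

    // Hypothetical: read the stored factor back so the heuristic and the
    // schedule cannot silently diverge. The split itself is unchanged; the
    // params field is consumed as a checked invariant.
    NVF_ERROR(
        rparams->redu_unroll_factor * rparams->bdimy == rparams->tma_tile_r,
        "redu_unroll_factor is inconsistent with the TMA tile split");
    redu_tv->split(2, rparams->bdimy); // outer extent == redu_unroll_factor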

