Skip to content

skip bcast domains in transpose scheduler mapped domain check#5967

Draft
liqiangxl wants to merge 2 commits intomainfrom
llu/use_pointwise_not_transpose
Draft

skip bcast domains in transpose scheduler mapped domain check#5967
liqiangxl wants to merge 2 commits intomainfrom
llu/use_pointwise_not_transpose

Conversation

@liqiangxl
Copy link
Collaborator

@liqiangxl liqiangxl commented Feb 17, 2026

test_moe.py has a segmented fusion with t0[b0, i1, i2] and t1[i1, b2], it uses transpose scheduler due to t1's inner most dim is [i1] which is different from t0's inner most dim [i2].
Ideally, tranpose scheduler should reject it and pointwise scheudler should be used.

This PR revised hasAtLeastTwoValidGroups by filtering out bcast domains and check if the shorter one is a sub-sequence of the longer one. e.g. filtered t0 is [i1,i2], filtered t1 is [i1]. t1 is a sub-sequence of t0 and they are considered as mapped and tranpose scheduler will reject it. This check ensures there is no actual permutation of concretized domains and pointwise should be used instead of transpose.

@liqiangxl
Copy link
Collaborator Author

!test

@github-actions
Copy link

Description

  • Skip broadcast domains in transpose scheduler domain mapping check

  • Allow flexible subsequence matching for filtered domains

  • Enable pointwise scheduler for cases with broadcast dimensions

  • Update validation to check original domains for broadcasts

Changes walkthrough

Relevant files
Enhancement
domain_map.cpp
Skip broadcast domains in transpose scheduler                       

csrc/scheduler/tools/domain_map.cpp

  • Modified hasAtLeastTwoValidGroups to filter out broadcast dimensions
  • Changed from strict equality to flexible subsequence matching
  • Updated comments to explain broadcast dimension handling
  • Fixed validation check to use original loop domains
  • +32/-11 

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Algorithm Correctness

    The new subsequence-based mapping algorithm is more permissive than the previous exact equality check. While this allows broadcast dimensions to be skipped as intended, verify that this doesn't incorrectly accept invalid mappings. The algorithm assumes that if all elements of the shorter sequence can be found in the longer sequence in order, then the domains are compatible. This seems correct for the broadcast use case but should be validated.

    // Filter out reduction and broadcast dimensions before comparing.
    // - the shorter must be a subsequence of the longer (same
    //   order), e.g. longer [i0,i1,i2] allows shorter [i0,i1], [i1,i2], [i0,i2];
    //   disallows [i1,i0], [i2,i1], [i2,i0].
    auto ref1_filtered =
        ref1_loop | TensorDomain::kNoReductions | TensorDomain::kNoBroadcasts;
    auto ref2_filtered =
        ref2_loop | TensorDomain::kNoReductions | TensorDomain::kNoBroadcasts;
    
    const auto n1 = std::ranges::distance(ref1_filtered);
    const auto n2 = std::ranges::distance(ref2_filtered);
    auto shorter = (n1 <= n2) ? ref1_filtered : ref2_filtered;
    auto longer = (n1 <= n2) ? ref2_filtered : ref1_filtered;
    auto it = std::ranges::begin(longer);
    auto end = std::ranges::end(longer);
    bool all_mapped = true;
    for (IterDomain* id_s : shorter) {
      it = std::ranges::find_if(it, end, [&](IterDomain* id_l) {
        return ca_map.areMapped(id_s, id_l, IdMappingMode::PERMISSIVE);
      });
      if (it == end) {
        all_mapped = false;
        break;
      }
      ++it;
    }
    Performance Impact

    The new algorithm uses nested search (std::ranges::find_if in a loop) which could be O(n*m) complexity vs the previous O(n) exact match. Given that tensor domains are typically small, this may not be significant, but the performance impact should be evaluated, especially for large tensor domains.

    for (IterDomain* id_s : shorter) {
      it = std::ranges::find_if(it, end, [&](IterDomain* id_l) {
        return ca_map.areMapped(id_s, id_l, IdMappingMode::PERMISSIVE);
      });
      if (it == end) {
        all_mapped = false;
        break;
      }
      ++it;
    Validation Logic Correction

    The validation check was corrected to examine the original loop domains for broadcasts rather than the filtered domains. This is the correct fix since we want to ensure broadcasts exist in the original domains. Verify this change doesn't break any existing validation assumptions.

    const bool any_bcast =
        std::ranges::any_of(
            ref1_loop, [](IterDomain* id) { return id->isBroadcast(); }) ||
        std::ranges::any_of(
            ref2_loop, [](IterDomain* id) { return id->isBroadcast(); });

    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    1 participant