From d5a1880193d6a37aa5e8f10f5e818dd8a22fc4a3 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Tue, 14 Apr 2026 19:18:12 +0800 Subject: [PATCH 1/3] [#18424][S-TIR] Fix Segfault when applying Parallel during TIR schedule rewriting --- .../rewrite_parallel_vectorize_unroll.cc | 68 ++++++++++--------- 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index ebaa58660e3a..e574e5326ca9 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -419,43 +419,47 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode { void InitializeWithTuneContext(const TuneContext& context) final {} bool Apply(const Schedule& sch) final { - s_tir::ParsedAnnotation parsed_root; - s_tir::SBlockRV root_rv{ffi::UnsafeInit()}; - while (s_tir::FindAnnotatedRootBlock(sch, &parsed_root, &root_rv)) { - for (s_tir::SBlockRV block_rv : sch->GetChildBlocks(root_rv)) { - ffi::Array loop_rvs = sch->GetLoops(block_rv); - if (loop_rvs.empty()) { - continue; - } - s_tir::ParsedAnnotation parsed = parsed_root; - s_tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed); - const int loops_num = loop_rvs.size(); - try { - if (parsed.num_parallel_loops == loops_num && parsed.num_vectorize_loops == loops_num) { - // Fuse, split, vectorize and parallelize - s_tir::RewriteFuseSplitParallelVectorize(sch, &loop_rvs, parsed.max_vectorize_extent); - } else { - // Parallel - if (parsed.num_parallel_loops > 0) { - s_tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs); + try { + s_tir::ParsedAnnotation parsed_root; + s_tir::SBlockRV root_rv{ffi::UnsafeInit()}; + while (s_tir::FindAnnotatedRootBlock(sch, &parsed_root, &root_rv)) { + for (s_tir::SBlockRV block_rv : sch->GetChildBlocks(root_rv)) { + ffi::Array loop_rvs = sch->GetLoops(block_rv); + if (loop_rvs.empty()) { + continue; + } + s_tir::ParsedAnnotation parsed = parsed_root; + s_tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed); + const int loops_num = loop_rvs.size(); + try { + if (parsed.num_parallel_loops == loops_num && parsed.num_vectorize_loops == loops_num) { + // Fuse, split, vectorize and parallelize + s_tir::RewriteFuseSplitParallelVectorize(sch, &loop_rvs, parsed.max_vectorize_extent); + } else { + // Parallel + if (parsed.num_parallel_loops > 0) { + s_tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs); + } + // Vectorize + if (parsed.num_vectorize_loops > 0) { + s_tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs); + } } - // Vectorize - if (parsed.num_vectorize_loops > 0) { - s_tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs); + // AutoUnroll + if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) { + TVM_FFI_ICHECK(parsed.unroll_explicit == -1 || parsed.unroll_implicit == -1); + int unroll_explicit = parsed.unroll_explicit != -1; + int max_step = parsed.unroll_explicit + parsed.unroll_implicit + 1; + s_tir::RewriteUnroll(sch, unroll_explicit, max_step, block_rv, loop_rvs[0]); } + } catch (const s_tir::ScheduleError& e) { + DLOG(WARNING) << "Failed to apply parallelization/vectorization: " << e.what(); + return false; } - // AutoUnroll - if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) { - TVM_FFI_ICHECK(parsed.unroll_explicit == -1 || parsed.unroll_implicit == -1); - int unroll_explicit = parsed.unroll_explicit != -1; - int max_step = parsed.unroll_explicit + parsed.unroll_implicit + 1; - s_tir::RewriteUnroll(sch, unroll_explicit, max_step, block_rv, loop_rvs[0]); - } - } catch (const s_tir::ScheduleError& e) { - DLOG(WARNING) << "Failed to apply parallelization/vectorization: " << e.what(); - return false; } } + } catch (const runtime::Error&) { + return false; } return true; } From 35cec9b527c92cbe9e6dcf3236bb8226c35c7a76 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Tue, 14 Apr 2026 20:10:31 +0800 Subject: [PATCH 2/3] revert try-catch and fix unsafe dereferences of IntImmNode --- .../rewrite_parallel_vectorize_unroll.cc | 91 +++++++++++-------- 1 file changed, 51 insertions(+), 40 deletions(-) diff --git a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index e574e5326ca9..201e51477692 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -213,6 +213,7 @@ void AdjustParallelVectorize(const Schedule& sch, const SBlockRV& block_rv, // (vectorizable) axes for (const BufferRegion& access : buffer_access) { int fusible = 0; + bool can_analyze_contiguous_access = true; std::vector strides; // get strides for each loop var for (const StmtSRef& loop_sref : loop_srefs) { @@ -226,10 +227,22 @@ void AdjustParallelVectorize(const Schedule& sch, const SBlockRV& block_rv, stride = coef * buffer_stride; break; } - buffer_stride *= access->buffer->shape[i].as()->value; + const auto* shape = access->buffer->shape[i].as(); + if (shape == nullptr) { + can_analyze_contiguous_access = false; + break; + } + buffer_stride *= shape->value; + } + if (!can_analyze_contiguous_access) { + break; } strides.push_back(stride); } + if (!can_analyze_contiguous_access) { + max_fusible = 0; + continue; + } int prev_used_iter = -1; // check the number of fusible loops for (int i = strides.size() - 1; i >= 0; i--) { @@ -246,9 +259,11 @@ void AdjustParallelVectorize(const Schedule& sch, const SBlockRV& block_rv, prev_used_iter = i; } else { // contiguous memory access - const auto* prev_loop = loop_srefs[prev_used_iter]->StmtAs(); - int64_t prev_used_iter_extent = prev_loop->extent.as()->value; - if (strides[i] == strides[prev_used_iter] * prev_used_iter_extent) { + const int64_t* prev_used_iter_extent = GetLoopIntExtent(loop_srefs[prev_used_iter]); + if (prev_used_iter_extent == nullptr) { + break; + } + if (strides[i] == strides[prev_used_iter] * (*prev_used_iter_extent)) { fusible++; prev_used_iter = i; } else { @@ -419,47 +434,43 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode { void InitializeWithTuneContext(const TuneContext& context) final {} bool Apply(const Schedule& sch) final { - try { - s_tir::ParsedAnnotation parsed_root; - s_tir::SBlockRV root_rv{ffi::UnsafeInit()}; - while (s_tir::FindAnnotatedRootBlock(sch, &parsed_root, &root_rv)) { - for (s_tir::SBlockRV block_rv : sch->GetChildBlocks(root_rv)) { - ffi::Array loop_rvs = sch->GetLoops(block_rv); - if (loop_rvs.empty()) { - continue; - } - s_tir::ParsedAnnotation parsed = parsed_root; - s_tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed); - const int loops_num = loop_rvs.size(); - try { - if (parsed.num_parallel_loops == loops_num && parsed.num_vectorize_loops == loops_num) { - // Fuse, split, vectorize and parallelize - s_tir::RewriteFuseSplitParallelVectorize(sch, &loop_rvs, parsed.max_vectorize_extent); - } else { - // Parallel - if (parsed.num_parallel_loops > 0) { - s_tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs); - } - // Vectorize - if (parsed.num_vectorize_loops > 0) { - s_tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs); - } + s_tir::ParsedAnnotation parsed_root; + s_tir::SBlockRV root_rv{ffi::UnsafeInit()}; + while (s_tir::FindAnnotatedRootBlock(sch, &parsed_root, &root_rv)) { + for (s_tir::SBlockRV block_rv : sch->GetChildBlocks(root_rv)) { + ffi::Array loop_rvs = sch->GetLoops(block_rv); + if (loop_rvs.empty()) { + continue; + } + s_tir::ParsedAnnotation parsed = parsed_root; + s_tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed); + const int loops_num = loop_rvs.size(); + try { + if (parsed.num_parallel_loops == loops_num && parsed.num_vectorize_loops == loops_num) { + // Fuse, split, vectorize and parallelize + s_tir::RewriteFuseSplitParallelVectorize(sch, &loop_rvs, parsed.max_vectorize_extent); + } else { + // Parallel + if (parsed.num_parallel_loops > 0) { + s_tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs); } - // AutoUnroll - if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) { - TVM_FFI_ICHECK(parsed.unroll_explicit == -1 || parsed.unroll_implicit == -1); - int unroll_explicit = parsed.unroll_explicit != -1; - int max_step = parsed.unroll_explicit + parsed.unroll_implicit + 1; - s_tir::RewriteUnroll(sch, unroll_explicit, max_step, block_rv, loop_rvs[0]); + // Vectorize + if (parsed.num_vectorize_loops > 0) { + s_tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs); } - } catch (const s_tir::ScheduleError& e) { - DLOG(WARNING) << "Failed to apply parallelization/vectorization: " << e.what(); - return false; } + // AutoUnroll + if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) { + TVM_FFI_ICHECK(parsed.unroll_explicit == -1 || parsed.unroll_implicit == -1); + int unroll_explicit = parsed.unroll_explicit != -1; + int max_step = parsed.unroll_explicit + parsed.unroll_implicit + 1; + s_tir::RewriteUnroll(sch, unroll_explicit, max_step, block_rv, loop_rvs[0]); + } + } catch (const s_tir::ScheduleError& e) { + DLOG(WARNING) << "Failed to apply parallelization/vectorization: " << e.what(); + return false; } } - } catch (const runtime::Error&) { - return false; } return true; } From 6f9b051fb99d417029801bb972f1d6c9a1b56723 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Tue, 14 Apr 2026 20:56:02 +0800 Subject: [PATCH 3/3] add test case --- .../rewrite_parallel_vectorize_unroll.cc | 2 +- ...tproc_rewrite_parallel_vectorize_unroll.py | 24 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index 201e51477692..d05c9a32cb18 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -241,7 +241,7 @@ void AdjustParallelVectorize(const Schedule& sch, const SBlockRV& block_rv, } if (!can_analyze_contiguous_access) { max_fusible = 0; - continue; + break; } int prev_used_iter = -1; // check the number of fusible loops diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py index f70f16ea6c45..e7baabb1e61c 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py @@ -181,6 +181,24 @@ def after_postproc_add( add_compute[v0, v1, v2, v3, v4] = lhs[v0, v1, v2, v3, v4] + rhs[v0, v1, v2, v3, v4] +@T.prim_func +def before_postproc_dynamic_shape_vectorize( + a: T.handle, + b: T.handle, +) -> None: + n = T.int64() + A = T.match_buffer(a, (n,), dtype="float32") + B = T.match_buffer(b, (n,), dtype="float32") + with T.block("root"): + T.block_attr({"meta_schedule.vectorize": 64}) + for i in T.serial(0, n): + with T.block("copy"): + vi = T.axis.spatial(n, i) + T.reads(A[vi]) + T.writes(B[vi]) + B[vi] = A[vi] + + # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable @@ -269,5 +287,11 @@ def expected(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "flo assert_structural_equal_ignore_global_symbol(mod["main"], expected) +def test_rewrite_parallel_vectorize_unroll_dynamic_shape_no_crash(): + sch = Schedule(before_postproc_dynamic_shape_vectorize) + rule = RewriteParallelVectorizeUnroll() + assert rule.apply(sch) + + if __name__ == "__main__": tvm.testing.main()