diff --git a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
index ebaa58660e3a..d05c9a32cb18 100644
--- a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
+++ b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -213,6 +213,7 @@ void AdjustParallelVectorize(const Schedule& sch, const SBlockRV& block_rv,
   // (vectorizable) axes
   for (const BufferRegion& access : buffer_access) {
     int fusible = 0;
+    bool can_analyze_contiguous_access = true;
     std::vector<int64_t> strides;
     // get strides for each loop var
     for (const StmtSRef& loop_sref : loop_srefs) {
@@ -226,10 +227,22 @@ void AdjustParallelVectorize(const Schedule& sch, const SBlockRV& block_rv,
           stride = coef * buffer_stride;
           break;
         }
-        buffer_stride *= access->buffer->shape[i].as<IntImmNode>()->value;
+        const auto* shape = access->buffer->shape[i].as<IntImmNode>();
+        if (shape == nullptr) {
+          can_analyze_contiguous_access = false;
+          break;
+        }
+        buffer_stride *= shape->value;
+      }
+      if (!can_analyze_contiguous_access) {
+        break;
       }
       strides.push_back(stride);
     }
+    if (!can_analyze_contiguous_access) {
+      max_fusible = 0;
+      break;
+    }
     int prev_used_iter = -1;
     // check the number of fusible loops
     for (int i = strides.size() - 1; i >= 0; i--) {
@@ -246,9 +259,11 @@ void AdjustParallelVectorize(const Schedule& sch, const SBlockRV& block_rv,
           prev_used_iter = i;
         } else {
           // contiguous memory access
-          const auto* prev_loop = loop_srefs[prev_used_iter]->StmtAs<ForNode>();
-          int64_t prev_used_iter_extent = prev_loop->extent.as<IntImmNode>()->value;
-          if (strides[i] == strides[prev_used_iter] * prev_used_iter_extent) {
+          const int64_t* prev_used_iter_extent = GetLoopIntExtent(loop_srefs[prev_used_iter]);
+          if (prev_used_iter_extent == nullptr) {
+            break;
+          }
+          if (strides[i] == strides[prev_used_iter] * (*prev_used_iter_extent)) {
             fusible++;
             prev_used_iter = i;
           } else {
diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
index f70f16ea6c45..e7baabb1e61c 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
@@ -181,6 +181,24 @@ def after_postproc_add(
         add_compute[v0, v1, v2, v3, v4] = lhs[v0, v1, v2, v3, v4] + rhs[v0, v1, v2, v3, v4]
 
 
+@T.prim_func
+def before_postproc_dynamic_shape_vectorize(
+    a: T.handle,
+    b: T.handle,
+) -> None:
+    n = T.int64()
+    A = T.match_buffer(a, (n,), dtype="float32")
+    B = T.match_buffer(b, (n,), dtype="float32")
+    with T.block("root"):
+        T.block_attr({"meta_schedule.vectorize": 64})
+        for i in T.serial(0, n):
+            with T.block("copy"):
+                vi = T.axis.spatial(n, i)
+                T.reads(A[vi])
+                T.writes(B[vi])
+                B[vi] = A[vi]
+
+
 # fmt: on
 # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable
 
@@ -269,5 +287,11 @@ def expected(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "flo
     assert_structural_equal_ignore_global_symbol(mod["main"], expected)
 
 
+def test_rewrite_parallel_vectorize_unroll_dynamic_shape_no_crash():
+    sch = Schedule(before_postproc_dynamic_shape_vectorize)
+    rule = RewriteParallelVectorizeUnroll()
+    assert rule.apply(sch)
+
+
 if __name__ == "__main__":
     tvm.testing.main()