Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ void AdjustParallelVectorize(const Schedule& sch, const SBlockRV& block_rv,
// (vectorizable) axes
for (const BufferRegion& access : buffer_access) {
int fusible = 0;
bool can_analyze_contiguous_access = true;
std::vector<int64_t> strides;
// get strides for each loop var
for (const StmtSRef& loop_sref : loop_srefs) {
Expand All @@ -226,10 +227,22 @@ void AdjustParallelVectorize(const Schedule& sch, const SBlockRV& block_rv,
stride = coef * buffer_stride;
break;
}
buffer_stride *= access->buffer->shape[i].as<IntImmNode>()->value;
const auto* shape = access->buffer->shape[i].as<IntImmNode>();
if (shape == nullptr) {
can_analyze_contiguous_access = false;
break;
}
buffer_stride *= shape->value;
}
if (!can_analyze_contiguous_access) {
break;
}
strides.push_back(stride);
}
if (!can_analyze_contiguous_access) {
max_fusible = 0;
break;
}
int prev_used_iter = -1;
// check the number of fusible loops
for (int i = strides.size() - 1; i >= 0; i--) {
Expand All @@ -246,9 +259,11 @@ void AdjustParallelVectorize(const Schedule& sch, const SBlockRV& block_rv,
prev_used_iter = i;
} else {
// contiguous memory access
const auto* prev_loop = loop_srefs[prev_used_iter]->StmtAs<ForNode>();
int64_t prev_used_iter_extent = prev_loop->extent.as<IntImmNode>()->value;
if (strides[i] == strides[prev_used_iter] * prev_used_iter_extent) {
const int64_t* prev_used_iter_extent = GetLoopIntExtent(loop_srefs[prev_used_iter]);
if (prev_used_iter_extent == nullptr) {
break;
}
if (strides[i] == strides[prev_used_iter] * (*prev_used_iter_extent)) {
fusible++;
prev_used_iter = i;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,24 @@ def after_postproc_add(
add_compute[v0, v1, v2, v3, v4] = lhs[v0, v1, v2, v3, v4] + rhs[v0, v1, v2, v3, v4]


@T.prim_func
def before_postproc_dynamic_shape_vectorize(
    a: T.handle,
    b: T.handle,
) -> None:
    # Regression fixture: an element-wise copy whose buffer extent `n` is a
    # symbolic (dynamic) shape rather than a compile-time constant.  The
    # "meta_schedule.vectorize" block attribute asks the postproc to vectorize
    # up to 64 lanes; with a dynamic extent the loop/buffer analysis cannot
    # cast shapes to IntImmNode, which previously dereferenced a null pointer.
    n = T.int64()  # dynamic extent shared by both buffers
    A = T.match_buffer(a, (n,), dtype="float32")
    B = T.match_buffer(b, (n,), dtype="float32")
    with T.block("root"):
        T.block_attr({"meta_schedule.vectorize": 64})
        for i in T.serial(0, n):
            with T.block("copy"):
                vi = T.axis.spatial(n, i)
                T.reads(A[vi])
                T.writes(B[vi])
                B[vi] = A[vi]


# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable

Expand Down Expand Up @@ -269,5 +287,11 @@ def expected(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "flo
assert_structural_equal_ignore_global_symbol(mod["main"], expected)


def test_rewrite_parallel_vectorize_unroll_dynamic_shape_no_crash():
    """The postproc must apply cleanly (no segfault / analysis failure) when the
    scheduled PrimFunc has a dynamic (symbolic) buffer shape."""
    schedule = Schedule(before_postproc_dynamic_shape_vectorize)
    postproc = RewriteParallelVectorizeUnroll()
    # apply() returns True on success; a crash here was the original bug.
    assert postproc.apply(schedule)


if __name__ == "__main__":
tvm.testing.main()
Loading